Repository: GU-CLASP/TypedFlow Branch: master Commit: 3a8fa230d413 Files: 39 Total size: 272.5 KB Directory structure: gitextract_rnbmpni2/ ├── .gitignore ├── LICENSE ├── Makefile ├── README.org ├── TypedFlow/ │ ├── Abstract.hs │ ├── Broadcast.hs │ ├── Haskell.hs │ ├── Layers/ │ │ ├── Core.hs │ │ ├── RNN/ │ │ │ ├── Attention.hs │ │ │ ├── Base.hs │ │ │ └── Cells.hs │ │ └── RNN.hs │ ├── Layers.hs │ ├── Learn.hs │ ├── Memo.hs │ ├── Memo2.hs │ ├── Models/ │ │ ├── Topic.hs │ │ └── Transformer.hs │ ├── Python.hs │ ├── TF.hs │ ├── Types/ │ │ └── Proofs.hs │ └── Types.hs ├── TypedFlow.hs ├── cabal.project ├── docs/ │ ├── HOT.org │ └── Talk.org ├── examples/ │ ├── agreement/ │ │ └── Aggr.hs │ ├── mnist/ │ │ ├── MNIST.hs │ │ ├── Makefile │ │ ├── main.py │ │ └── mnist_model.py │ └── seq2seq/ │ ├── GenTr.hs │ ├── Makefile │ ├── Seq2Seq.hs │ ├── main.py │ └── shell.nix ├── styx.yaml ├── typedflow.cabal └── typedflow_rts.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .styx *~ dist dist-* cabal-dev *.o *.hi *.chi *.chs.h *.dyn_o *.dyn_hi .hpc .hsenv .cabal-sandbox/ cabal.sandbox.config *.prof *.aux *.hp *.eventlog .stack-work/ cabal.project.local .HTF/ /examples/seq2seq/s2s.py /examples/seq2seq/synthtrees.txt MNIST_data __pycache__ /examples/seq2seq/GenTr /.tramp_history ================================================ FILE: LICENSE ================================================ GNU LESSER GENERAL PUBLIC LICENSE Version 3, 29 June 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. This version of the GNU Lesser General Public License incorporates the terms and conditions of version 3 of the GNU General Public License, supplemented by the additional permissions listed below. 0. Additional Definitions. As used herein, "this License" refers to version 3 of the GNU Lesser General Public License, and the "GNU GPL" refers to version 3 of the GNU General Public License. "The Library" refers to a covered work governed by this License, other than an Application or a Combined Work as defined below. An "Application" is any work that makes use of an interface provided by the Library, but which is not otherwise based on the Library. Defining a subclass of a class defined by the Library is deemed a mode of using an interface provided by the Library. A "Combined Work" is a work produced by combining or linking an Application with the Library. The particular version of the Library with which the Combined Work was made is also called the "Linked Version". The "Minimal Corresponding Source" for a Combined Work means the Corresponding Source for the Combined Work, excluding any source code for portions of the Combined Work that, considered in isolation, are based on the Application, and not on the Linked Version. The "Corresponding Application Code" for a Combined Work means the object code and/or source code for the Application, including any data and utility programs needed for reproducing the Combined Work from the Application, but excluding the System Libraries of the Combined Work. 1. Exception to Section 3 of the GNU GPL. You may convey a covered work under sections 3 and 4 of this License without being bound by section 3 of the GNU GPL. 2. Conveying Modified Versions. If you modify a copy of the Library, and, in your modifications, a facility refers to a function or data to be supplied by an Application that uses the facility (other than as an argument passed when the facility is invoked), then you may convey a copy of the modified version: a) under this License, provided that you make a good faith effort to ensure that, in the event an Application does not supply the function or data, the facility still operates, and performs whatever part of its purpose remains meaningful, or b) under the GNU GPL, with none of the additional permissions of this License applicable to that copy. 3. Object Code Incorporating Material from Library Header Files. The object code form of an Application may incorporate material from a header file that is part of the Library. You may convey such object code under terms of your choice, provided that, if the incorporated material is not limited to numerical parameters, data structure layouts and accessors, or small macros, inline functions and templates (ten or fewer lines in length), you do both of the following: a) Give prominent notice with each copy of the object code that the Library is used in it and that the Library and its use are covered by this License. b) Accompany the object code with a copy of the GNU GPL and this license document. 4. Combined Works. You may convey a Combined Work under terms of your choice that, taken together, effectively do not restrict modification of the portions of the Library contained in the Combined Work and reverse engineering for debugging such modifications, if you also do each of the following: a) Give prominent notice with each copy of the Combined Work that the Library is used in it and that the Library and its use are covered by this License. b) Accompany the Combined Work with a copy of the GNU GPL and this license document. c) For a Combined Work that displays copyright notices during execution, include the copyright notice for the Library among these notices, as well as a reference directing the user to the copies of the GNU GPL and this license document. d) Do one of the following: 0) Convey the Minimal Corresponding Source under the terms of this License, and the Corresponding Application Code in a form suitable for, and under terms that permit, the user to recombine or relink the Application with a modified version of the Linked Version to produce a modified Combined Work, in the manner specified by section 6 of the GNU GPL for conveying Corresponding Source. 1) Use a suitable shared library mechanism for linking with the Library. A suitable mechanism is one that (a) uses at run time a copy of the Library already present on the user's computer system, and (b) will operate properly with a modified version of the Library that is interface-compatible with the Linked Version. e) Provide Installation Information, but only if you would otherwise be required to provide such information under section 6 of the GNU GPL, and only to the extent that such information is necessary to install and execute a modified version of the Combined Work produced by recombining or relinking the Application with a modified version of the Linked Version. (If you use option 4d0, the Installation Information must accompany the Minimal Corresponding Source and Corresponding Application Code. If you use option 4d1, you must provide the Installation Information in the manner specified by section 6 of the GNU GPL for conveying Corresponding Source.) 5. Combined Libraries. You may place library facilities that are a work based on the Library side by side in a single library together with other library facilities that are not Applications and are not covered by this License, and convey such a combined library under terms of your choice, if you do both of the following: a) Accompany the combined library with a copy of the same work based on the Library, uncombined with any other library facilities, conveyed under the terms of this License. b) Give prominent notice with the combined library that part of it is a work based on the Library, and explaining where to find the accompanying uncombined form of the same work. 6. Revised Versions of the GNU Lesser General Public License. The Free Software Foundation may publish revised and/or new versions of the GNU Lesser General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Library as you received it specifies that a certain numbered version of the GNU Lesser General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that published version or of any later version published by the Free Software Foundation. If the Library as you received it does not specify a version number of the GNU Lesser General Public License, you may choose any version of the GNU Lesser General Public License ever published by the Free Software Foundation. If the Library as you received it specifies that a proxy can decide whether future versions of the GNU Lesser General Public License shall apply, that proxy's public statement of acceptance of any version is permanent authorization for you to choose that version for the Library. ================================================ FILE: Makefile ================================================ viewdoc: dist/doc/html/typedflow/index.html xdg-open $< dist/doc/html/typedflow/index.html: styx cabal -- haddock --hyperlink-source styx cabal -- hscolour ================================================ FILE: README.org ================================================ #+TITLE: TypedFlow TypedFlow is a typed, higher-order frontend to [[http://www.tensorflow.org][TensorFlow]] and a high-level library for deep-learning. The main design principles are: - To make the parameters of layers explicit. This choice makes sharing of parameters explicit and allows to implement "layers" as pure functions. - To provide as precise as possible types. Functions are explicit about the shapes and elements of the tensors that they manipulate (they are often polymorphic in shapes and elements though.) - To let combinators be as transparent as possible. If a NN layers is a simple tensor transformation it will be exposed as such. In this version, the interface to TensorFlow is done via python-code generation and a suitable runtime system. ** Documentation The compiled documentation should be found on [[https://hackage.haskell.org/package/typedflow][hackage]]. ** Examples TypedFlow comes with two examples of neural networks: - An adaptation of the [[examples/mnist][MNIST tensorflow tutorial]] - A simple [[examples/seq2seq][sequence to sequence model]] which attempts to learn to translate pre-order into post-order. To running the examples can be done like so: #+BEGIN_SRC shell nix-env -iA nixpkgs.haskellPackages.styx nix-env -iA nixpkgs.cabal2nix styx configure cd examples/seq2seq make #+END_SRC ================================================ FILE: TypedFlow/Abstract.hs ================================================ {-# LANGUAGE InstanceSigs #-} {-| Module : TypedFlow.Abstract Description : Abstract Tensor representations Copyright : (c) Jean-Philippe Bernardy, 2018 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental This module provides operations on the abstract representation of tensor operations. It is not normally imported directly by users. -} {-# LANGUAGE ApplicativeDo #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UnicodeSyntax #-} {-# LANGUAGE CPP #-} #if __GLASGOW_HASKELL__ >= 806 {-# LANGUAGE NoStarIsType #-} #endif module TypedFlow.Abstract where import Control.Monad.RWS (RWS, tell, runRWS) import Control.Monad.State import Data.Proxy import Data.Type.Equality import GHC.TypeLits import Prelude hiding (RealFrac(..)) import qualified TypedFlow.Memo as Memo0 import TypedFlow.Types import TypedFlow.Broadcast import TypedFlow.Types.Proofs freeVarsT :: forall s t. KnownTyp t => KnownShape s => T s t -> [Int] freeVarsT x = result where f :: forall s' t'. T s' t' -> [Int] f = Memo0.memo (protoFreevars f) result = f x protoFreevars :: (forall s' t'. T s' t' -> [Int]) -> T s t -> [Int] protoFreevars rec = \case BroadcastT _ _ _ _ x -> rec x MapT _ s f x -> rec x <> rec (f (T (Variable (Ref (-789) s typeSTyp)))) Softmax _ _ x -> rec x DirectBroadcast _ _ _ _ x -> rec x GatherND _ _ _ x y -> rec x <> rec y Noise _ _ _ _ -> [] Where cond x y -> rec cond <> rec x <> rec y If cond x y -> rec cond <> rec x <> rec y T (Variable (Ref i _ _)) -> [i] T _ -> [] Unbroadcast _p _u x -> rec x UnOp _op _ x -> rec x MatMul _ _ _ _ x y -> rec x <> rec y BinOp _op _ _ _ _ _ x y -> rec x <> rec y Gather _is _s0 _m _s1 x ix -> rec x <> rec ix Transpose _ _t x -> rec x ReshapeFrom _s x -> rec x Concat _s0 _s1 xs -> mconcat $ htoList $ hmap (\(Catable _ x) -> K (rec x)) xs Convolution _bs _inChans _outChans _filterShape _s x filters -> rec x <> rec filters Pool _ _ _ _ _ x -> rec x _ -> error "protoFreevars: unhandled case" genTrainingPlaceholder :: Scalar TFBool genTrainingPlaceholder = T (ExternalVar (Ref "training_placeholder" typeSShape typeSTyp)) -- | Zeros zeros :: ∀ t (shape :: Shape). KnownNumeric t => KnownShape shape => (T shape t) zeros = constant $ knownNum @t $ 0 defaultT :: ∀ t (shape :: Shape). KnownShape shape => KnownTyp t => (T shape t) defaultT = case typeSTyp @t of STyp SFloat _ _ -> zeros STyp SInt _ _ -> zeros STyp SBool _ _ -> constant False _ -> error "defaultT: unhandled case" -- | Ones ones :: ∀ t (shape :: Shape). KnownShape shape => KnownNumeric t => (T shape t) ones = knownNum @t $ constant 1 -- | Identity matrix in dimensions n,n eye :: ∀ n t. KnownNat n => KnownNumeric t => (T '[n,n] t) eye = diag 1 diag :: ∀ n t. KnownTyp t => KnownNat n => T '[n] t -> T '[n,n] t diag = UnOp (Diag Sat) Unit expm :: ∀ n t. KnownNat n => KnownNumeric t => T '[n,n] t -> T '[n,n] t expm = UnOp (ExpM Sat) Unit -- | @k@=diagonal above which to zero elements. k = 0 is the main diagonal, k < 0 is below it and k > 0 is above. tril :: ∀ n t. KnownNat n => KnownNumeric t => Integer -> T '[n,n] t -> T '[n,n] t tril k = UnOp (ZeroTriangle Sat Lower k) Unit triu :: ∀ n t. KnownNat n => KnownNumeric t => Integer -> T '[n,n] t -> T '[n,n] t triu k = UnOp (ZeroTriangle Sat Upper k) Unit -- | Constant constant :: forall s t w. KnownShape s => KnownBits w => KnownKind t => HaskType ('Typ t w) -> T s ('Typ t w) constant c = appRUnit @s #> broadcastTT @s (scalar c) scalar :: forall t w. KnownBits w => KnownKind t => HaskType ('Typ t w) -> Scalar ('Typ t w) scalar = T . Constant reduceAll :: forall s t. KnownTyp t => KnownShape s => (∀n s'. (KnownTyp t,KnownShape s') => Axis n s' -> T s' t -> T (Take n s' ++ Drop ('Succ n) s') t) -> Tensor s t -> Tensor '[] t reduceAll op x = knownProduct @s ?> op axis0 (reshapeTo ((:*) (productS (typeSShape @s)) Unit) x) -- | Mean value of the input tensor. reduceMeanAll, reduceSumAll, reduceMaxAll, reduceMinAll :: ∀ (s :: Shape) t. KnownNumeric t => KnownShape s => Tensor s t -> Tensor '[] t reduceMaxAll = reduceAll reduceMax reduceMeanAll = reduceAll reduceMean reduceSumAll = reduceAll reduceSum reduceMinAll = reduceAll reduceMin sShapeTake' :: Axis n s -> SList' f s -> SList' f (Take n s) sShapeTake' AxZero _s = Unit sShapeTake' (AxSucc n) ((:*) x xs) = (:*) x (sShapeTake' n xs) sShapeDrop' :: Axis n s -> SList' f s -> SList' f (Drop n s) sShapeDrop' AxZero s = s sShapeDrop' (AxSucc n) ((:*) _ xs) = sShapeDrop' n xs sShapeDropSucc :: Axis n s -> SList' f s -> SList' f (Drop ('Succ n) s) sShapeDropSucc AxZero (_ :* s) = s sShapeDropSucc (AxSucc n) (_ :* xs) = sShapeDropSucc n xs -- | Internal. Use 'reduceSum', etc. instead. reduce :: ∀ n s t. KnownNumeric t => (KnownShape s) => ReduceOp -> Axis n s -> T s t -> T (Take n s ++ Drop ('Succ n) s) t reduce op n x = case axisSplitApp' n of Refl -> UnOp (Axis1Op (sShapeDropSucc n s) (ReduceOp (hlookup n s) op)) (sShapeTake' n s) x where s = typeSShape @s -- | Reduce along a given dimension reduceSum, reduceMean, reduceMax, reduceMin :: ∀n s t. (KnownNumeric t,KnownShape s) => Axis n s -> T s t -> T (Take n s ++ Drop ('Succ n) s) t reduceSum = reduce Sum reduceMean = reduce Mean reduceMax = reduce Max reduceMin = reduce Min -- | Sum along the first dimension reduceSum0 :: ∀ s' n t. KnownNat n => KnownNumeric t => KnownShape s' => Tensor (n ': s') t -> Tensor s' t reduceSum0 = reduceSum axis0 addN :: ∀ s t. KnownNumeric t => KnownShape s => [Tensor s t] -> Tensor s t addN [] = zeros addN ts = foldr1 (+) ts instance (KnownNumeric t, KnownShape s) => Num (T s t) where (+) = (⊕) (*) = (⊙) signum = unOp Sign fromInteger x = knownNum @t $ constant (fromIntegral x) abs = unOp Abs (-) = (⊝) negate = unOp Negate instance (KnownFloat b, KnownShape s) => Fractional (T s b) where fromRational x = knownAlgebraic @b $ constant (fromRational x :: HaskType b) (/) = (⊘) instance (KnownFloat b, KnownShape s) => Floating (T s b) where pi = knownAlgebraic @b $ constant pi exp = unFlOp Exp log = unFlOp Log sin = unFlOp Sin cos = unFlOp Cos asin = unFlOp Asin acos = unFlOp Acos sinh = unFlOp Sinh cosh = unFlOp Cosh asinh = unFlOp Asinh acosh = unFlOp Acosh tanh = unFlOp Tanh atan = unFlOp Atan atanh = unFlOp Atanh sqrt = unFlOp Sqrt -- | Pretend that the argument is a constant for the purposes of -- gradient computation stopGradient :: ∀ s t. KnownTyp t => KnownShape s => Tensor s t -> Tensor s t stopGradient = appRUnit @s #> UnOp StopGradient (typeSShape @s) -- | Divide tensors, broacasting along shape @s@ (⊘) :: forall s t. KnownAlgebraic t => KnownShape s => T s t -> T s t -> T s t (⊘) = binOp Divide -- | Divide tensors, broacasting along shape @s@ floorDiv :: forall s w. KnownBits w => KnownShape s => T s ('Typ 'Int w) -> T s ('Typ 'Int w) -> T s ('Typ 'Int w) floorDiv = binOp IntegerDiv -- | Indexwise equality test. equal :: forall s t. (KnownShape s, KnownTyp t) => Tensor s t -> Tensor s t -> Tensor s TFBool equal = binOp (Equal) -- | Indexwise operator (⊕), (⊝), (⊙) :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s t (⊝) = binOp Subtract (⊙) = binOp Multiply (⊕) = binOp Add maxT,minT :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s t maxT = binOp Maximum minT = binOp Minimum mkComplex :: KnownBits w => KnownShape s => Tensor s (Flt w) -> Tensor s (Flt w) -> Tensor s ('Typ 'Cmplx w) mkComplex = binOp MkComplex lessThan :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s TFBool lessThan = binOp (Comparision Less) lessOrEqualThan :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s TFBool lessOrEqualThan = binOp (Comparision LessOrEqual) greaterThan :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s TFBool greaterThan = binOp (Comparision Greater) logicAnd :: ∀ (s :: Shape). (KnownShape s) => Tensor s TFBool -> Tensor s TFBool-> Tensor s TFBool logicAnd = binOp (Logic And) infixl 7 ⊙,⊘ infixl 6 ⊕,⊝ -- | Matrix multiplication (note that shape @s@ is preserved) matmul :: forall m n o t. KnownNumeric t => KnownNat m => KnownNat o => KnownNat n => KnownTyp t => T '[n,o] t -> T '[o,m] t -> T '[n,m] t matmul = MatMul Unit Sat Sat Sat unOp :: forall s t. KnownShape s => KnownNumeric t => Num1Op -> T s t -> T s t unOp op = appRUnit @s #> UnOp (Num1Op op) (typeSShape @s) unFlOp :: forall s t. KnownBits t => KnownShape s => Float1Op -> T s (Flt t) -> T s (Flt t) unFlOp op = appRUnit @s #> UnOp (Float1Op op) (typeSShape @s) binOp :: forall s t u. KnownShape s => KnownTyp t => Simple2Op t u -> T s t -> T s t -> T s u binOp op = appRUnit @s #> BinOp (Simple2Op op) (typeSShape @s) Unit typeSTyp Unit typeSTyp conjugate :: ∀ s w. KnownShape s => KnownBits w => T s ('Typ 'Cmplx w) -> T s ('Typ 'Cmplx w) conjugate = appRUnit @s #> UnOp Conjugate (typeSShape @s) realPart :: ∀ s w. KnownShape s => KnownBits w => T s ('Typ 'Cmplx w) -> T s ('Typ 'Float w) realPart = appRUnit @s #> UnOp RealPart (typeSShape @s) sigmoid, relu, square, round, floor, hardSigmoid :: ∀ s t. (KnownShape s, KnownBits t) => Tensor s ('Typ 'Float t) -> Tensor s ('Typ 'Float t) sigmoid = unFlOp Sigmoid hardSigmoid = unFlOp HardSigmoid square = unOp Square relu = unFlOp Relu floorMod :: ∀ s t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s t floorMod = binOp FloorMod -- Unfortunately RealFrac is utterly broken; so we have to do this: round = unFlOp Round floor = unFlOp Floor -- | Take a slice at dimension n from i to j. slice :: forall i j s t n. KnownTyp t => KnownShape s => KnownNat j => KnownNat i => (i <= j, j <= At n s, KnownLen s) => Axis n s -> Tensor s t -> Tensor (Take n s ++ ((j-i) ': Drop ('Succ n) s)) t slice n = case axisSplitApp' n of Refl -> UnOp (Axis1Op (sShapeDropSucc n s) (SliceOp (Proxy @(j-i)) (hlookup n s) (natVal (Proxy @i)) (natVal (Proxy @j)))) (sShapeTake' n s) where s = typeSShape @s slice1 :: forall i j m n s t. KnownShape s => KnownNat m => KnownNat n => KnownTyp t => KnownNat j => KnownNat i => (i <= j, j <= m, KnownLen s) => Tensor (n ': m ': s) t -> Tensor (n ': (j-i) ': s) t slice1 = slice @i @j axis1 slice0 :: forall i j m s t. KnownShape s => KnownNat m => KnownTyp t => KnownNat j => KnownNat i => (i <= j, j <= m, KnownLen s) => Tensor (m ': s) t -> Tensor ((j-i) ': s) t slice0 = slice @i @j axis0 -- | Concatenate tensors on dimension @n@. Recommended: use @zipWithTT (concat0 ...)@ instead. concatT :: ∀ n d1 d2 s t. KnownNat d2 => KnownNat d1 => KnownShape s => (KnownTyp t, (d1+d2) ~ At n s) => Axis n s -> T (Take n s ++ (d1 ': Drop ('Succ n) s)) t -> T (Take n s ++ (d2 ': Drop ('Succ n) s)) t -> T s t concatT n = case axisSplitApp' n of Refl -> concatT' (sShapeTake' n s) d1 d2 (sShapeDropSucc n s) where s = typeSShape @s; d1 = natSat @d1; d2 = natSat @d2 -- | Concatenate tensors on the first dimension concat0, (©) :: ∀ d1 d2 ys t. KnownTyp t => KnownShape ys => KnownNat d2 => KnownNat d1 => (KnownLen ys) => T (d1 ': ys) t -> T (d2 ': ys) t -> T ((d1 + d2) ': ys) t concat0 = concatT axis0 (©) = concat0 -- | Concatenate tensors on the second dimension concat1 :: ∀ n ys d1 d2 t. KnownShape ys => KnownNat n => KnownNat d2 => KnownNat d1 => KnownTyp t => (KnownLen ys) => T (n ': d1 ': ys) t -> T (n ': d2 ': ys) t -> T (n ': (d1 + d2) ': ys) t concat1 = concatT axis1 -- | Add an extra dimension at axis (@n@) of size 1. expandDim :: forall n s t. KnownTyp t => KnownShape s => (PeanoNat n <= Length s) => Tensor s t -> Tensor (Take n s ++ (1 ': Drop n s)) t expandDim x = -- Product (Take n s ++ (1 ': Drop n s)) prodHomo @(Take n s) @(1' : Drop n s) #> -- Product (Take n s) * Product (Drop n s) prodHomo @(Take n s) @(Drop n s) #> -- Product (Take n s ++ (1 ': Drop n s)) takeDrop @s @n #> -- Product s reshapeFrom (typeSShape @s) x -- +expandDim :: forall n s t. KnownTyp t => KnownShape s => Axis n s -> Tensor s t -> Tensor (Take n s ++ (1 ': Drop n s)) t -- +expandDim ax x = case expandDimProof ax s of Refl -> reshapeFrom s x -- | Add an extra dimension at axis (0) of size 1. expandDim0 :: ∀ s t. KnownShape s => KnownTyp t => KnownLen s => Tensor s t -> Tensor (1 ': s) t expandDim0 = reshape -- | Add an extra dimension at axis (1) of size 1. expandDim1 :: ∀ n s t. KnownNat n => KnownTyp t => KnownShape s => Tensor (n ': s) t -> Tensor (n ': 1 ': s) t expandDim1 = reshape -- | Flatten all the dimensions of the tensor flattenAll :: forall s t. KnownTyp t => KnownShape s => Tensor s t -> Tensor '[Product s] t flattenAll = knownProduct @s ?> reshape inflateAll :: forall s t. KnownTyp t => KnownShape s => Tensor '[Product s] t -> Tensor s t inflateAll = knownProduct @s ?> reshape squeeze0 :: ∀ s t. KnownTyp t => (KnownShape s) => Tensor (1 ': s) t -> Tensor s t squeeze0 = reshape atShape :: SList s -> T s t -> T s t atShape _ x = x -- | Reshape a tensor so that the last two dimensions are collapsed flattenN2 :: ∀ s m n t. KnownTyp t => (KnownNat m, KnownNat n, KnownShape s) => Tensor (s ++ '[m,n]) t -> Tensor (s ++ '[m*n]) t flattenN2 = prodHomo @s @'[m,n] #> prodHomo @s @'[m*n] #> knownAppend @s @'[m*n] ?> knownAppend @s @'[m,n] ?> reshape -- | Reshape a tensor so that the first three dimensions are collapsed flatten3 :: ∀ m n o s t. KnownTyp t => (KnownNat m, KnownNat n, KnownNat o, KnownShape s) => Tensor (m ': n ': o ': s) t -> Tensor (m*n*o ': s) t flatten3 = -- (m * (n * (o * Product s))) prodAssoc @m @n @(o * Product s) #> -- (m * n) * (o * Product s) prodAssoc @(m * n) @o @(Product s) #> -- ((m * n) * o) * Product s reshape -- | Reshape a tensor so that the first two dimensions are collapsed flatten12 :: ∀ m n o s t. KnownTyp t => KnownNat o => (KnownNat m, KnownNat n, KnownShape s) => Tensor (o ': m ': n ': s) t -> Tensor (o ': m*n ': s) t flatten12 = prodAssoc @m @n @(Product s) #> reshape -- | Reshape a tensor so that the first dimension is expanded into three. inflate3 :: ∀ m n o s t. KnownTyp t => (KnownNat m, KnownNat n, KnownNat o, KnownShape s) => Tensor (m*n*o ': s) t -> Tensor (m ': n ': o ': s) t inflate3 = -- (m * (n * (o * Product s))) prodAssoc @m @n @(o * Product s) #> -- (m * n) * (o * Product s) prodAssoc @(m * n) @o @(Product s) #> -- ((m * n) * o) * Product s reshape -- | Reshape a tensor so that the first two dimensions are collapsed inflate12 :: ∀ m n o s t. KnownTyp t => KnownNat o => (KnownNat m, KnownNat n, KnownShape s) => Tensor (o ': m*n ': s) t -> Tensor (o ': m ': n ': s) t inflate12 = prodAssoc @m @n @(Product s) #> reshape -- | Access the last element in a tensor (in the 0th dimension) last0 :: ∀ n s t. KnownShape s => KnownTyp t => KnownNat n => KnownLen s => T (n ': s) t -> Tensor s t last0 = nth0 (natVal (Proxy @n) - 1) -- | Access the nth element in a tensor (in the 0th dimension) nth0 :: ∀ n s t. KnownTyp t => KnownNat n => KnownShape s => Integer -> T (n ': s) t -> Tensor s t nth0 i x = UnOp (Axis1Op (typeSShape @s) (AccessOp (natSat @n) i)) Unit x -- | Access the nth element in a tensor (in the 0th dimension), with a static index nth0' :: ∀ n m s t. KnownNat m => KnownTyp t => KnownShape s => KnownNat n => KnownLen s => n < m => T (m ': s) t -> Tensor s t nth0' = nth0 (natVal (Proxy @n)) vecToNP :: forall a f n k. (a -> f 1) -> V n a -> (forall xs. Sum xs ~ n => NP f xs -> k) -> k vecToNP _f VUnit k = k Unit vecToNP f (x :** xs) k = vecToNP f xs $ \xs' -> k (f x :* xs') stackT :: ∀ s0 s (n::Nat) t. KnownShape s => KnownShape s0 => KnownNat n => (KnownLen s0) => V n (T (s0 ++ s) t) -> Tensor (s0 ++ (n ': s)) t stackT v = vecToNP @(T (s0++s) t) @(Catable s0 s t) (\x -> (Catable (natSat @1) $ (prodHomoS s0 s #> prodHomoS s0 (natSat @1 :* s) #> knownAppend @s0 @s ?> knownSShape (s0 .+. natSat @1 :* s) ?> reshape x))) v $ (Concat (typeSShape @s0) (typeSShape @s)) where s = typeSShape @s; s0 = typeSShape @s0 -- | Concatenate @n@ tensors along the first dimension stack0 :: ∀ s (n::Nat) t. KnownNat n => KnownShape s => (KnownLen s) => V n (T s t) -> Tensor (n ': s) t stack0 = stackT @'[] -- | Concatenate @n@ tensors along the second dimension stack1 :: ∀ s (n::Nat) m t. KnownNat n => KnownNat m => KnownShape s => (KnownLen s) => V n (T (m ': s) t) -> Tensor (m ': n ': s) t stack1 = stackT @'[m] -- | Concatenate @n@ tensors along the last dimension stackN :: ∀ s (n::Nat) t. KnownNat n => KnownShape s => V n (T s t) -> Tensor (s ++ '[n]) t stackN = appRUnit @s #> stackT @s @'[] -- | Split a tensors into @n@ tensors along the first dimension unstack0 :: ∀ s (n::Nat) t. KnownTyp t => KnownNat n => KnownShape s => (KnownLen s) => Tensor (n ': s) t -> V n (T s t) unstack0 x = fmap (`nth0` x) (vcount @n) -- | Stack a tensor vector. (To be used on literal lists of tensors.) litStack0 :: KnownShape s => KnownLen xs => TV s t xs -> Tensor (Length xs ': s) t litStack0 tv = knownSList tv ?> stack0 $ toV tv where toV :: TV s t xs -> V (Length xs) (T s t) toV Unit = VUnit toV (K x :* xs) = x :** toV xs -- | Generate a mask of given length for each sequence. sequenceMask :: forall maxlen. KnownNat maxlen => Tensor '[] Int32 -> Tensor '[maxlen] TFBool sequenceMask lens = mapT (lens `lessThan`) (range @maxlen) -- | simple broadcasting of a tensor (like a zero-arity map) broadcastT :: forall n s t. KnownShape s => KnownNat n => KnownTyp t => KnownLen s => T s t -> T (n ': s) t broadcastT x = BroadcastT Nothing False (natSat @n) typeSShape x -- | simple broadcasting of a tensor broadcastTT :: forall a s t. KnownShape s => KnownTyp t => KnownShape a => KnownLen s => T s t -> T (a ++ s) t broadcastTT x = prodHomo @a @s #> knownProduct @a ?> knownAppend @a @s ?> reshape (broadcastT @(Product a) x) -- | Map a function along the first dimension of a tensor mapT :: forall n s r t u. KnownShape s => KnownNat n => KnownTyp t => KnownLen r => KnownLen s => (T s t -> T r u) -> T (n ': s) t -> T (n ': r) u mapT f x = MapT Sat typeSShape f x -- | Map a function along the few first dimensions of a tensor, given by the first type parameter mapTT :: forall a s t r u. KnownShape r => KnownShape a => KnownTyp u => KnownLen r => KnownShape s => KnownTyp t => (T s t -> T r u) -> T (a ++ s) t -> T (a ++ r) u mapTT f x = prodHomo @a @r #> prodHomo @a @s #> knownProduct @a ?> knownAppend @a @r ?> knownAppend @a @s ?> reshape (mapT @(Product a) f (reshape x)) -- | zip a function along the first dimension of two tensors tensors zipWithT :: forall (n :: Nat) (s :: [Nat]) (t :: Typ) (s1 :: [Nat]) (t1 :: Typ) (s2 :: Shape) (t2 :: Typ). KnownShape s => KnownShape s1 => KnownNat n=> KnownTyp t => KnownTyp t1 => (T s t -> T s1 t1 -> T s2 t2) -> Tensor (n ': s) t -> Tensor (n ': s1) t1 -> Tensor (n ': s2) t2 zipWithT f x y = ZipT Sat typeSShape typeSShape f x y -- | zip a function along the few first dimensions of a tensor, given by the first type parameter zipWithTT :: forall a (s :: [Nat]) (s1 :: [Nat]) (s2 :: Shape) (t :: Typ) (t1 :: Typ) (t2 :: Typ). KnownTyp t1 => KnownTyp t => KnownShape s => KnownShape s1 => KnownShape a => KnownShape s2 => KnownTyp t2 => (T s t -> T s1 t1 -> T s2 t2) -> Tensor (a ++ s) t -> Tensor (a ++ s1) t1 -> Tensor (a ++ s2) t2 zipWithTT f x y = prodHomo @a @s1 #> prodHomo @a @s2 #> prodHomo @a @s #> knownProduct @a ?> knownAppend @a @s1 ?> knownAppend @a @s2 ?> knownAppend @a @s ?> reshape (zipWithT @(Product a) f (reshape x) (reshape y)) zipWith3T :: forall (n :: Nat) (s :: [Nat]) (t :: Typ) (s1 :: [Nat]) (t1 :: Typ) (s2 :: Shape) (t2 :: Typ) (s3 :: Shape) (t3 :: Typ). KnownShape s => KnownShape s1 => KnownShape s2 => KnownShape s3 => KnownNat n => KnownTyp t3 => KnownTyp t => KnownTyp t1 => KnownTyp t2 => (T s t -> T s1 t1 -> T s2 t2 -> T s3 t3) -> Tensor (n ': s) t -> Tensor (n ': s1) t1 -> Tensor (n ': s2) t2 -> Tensor (n ': s3) t3 zipWith3T = Zip3T Sat typeSShape typeSShape typeSShape -- | Size-preserving convolution operation. convolution :: forall outputChannels filterSpatialShape inChannels s t. KnownShape s => KnownNat inChannels => KnownNat outputChannels => KnownShape filterSpatialShape => KnownAlgebraic t => Length filterSpatialShape <= 3 => Length s ~ Length filterSpatialShape => T (s ++ '[inChannels]) t -- ^ input tensor -> T (filterSpatialShape ++ '[inChannels,outputChannels]) t -- ^ filters -> T (s ++ '[outputChannels]) t convolution x filters = knownAppend @s @'[outputChannels] ?> knownAppend @s @'[inChannels] ?> squeeze0 (Convolution (natSat @1) (natSat @inChannels) (natSat @outputChannels) (typeSShape @filterSpatialShape) (typeSShape @s) (expandDim0 x) filters) -- | Softmax along the first dimension softmaxInternal :: forall bs n w. KnownNat bs => KnownBits w => KnownNat n => T '[bs,n] ('Typ 'Float w) -> T '[bs,n] ('Typ 'Float w) softmaxInternal = Softmax (natSat @bs) (natSat @n) softmax0 :: forall n w. KnownBits w => KnownNat n => T '[n] (' Typ 'Float w) -> T '[n] ('Typ 'Float w) softmax0 = reshape . softmaxInternal . reshape @[1,n] -- | Softmax along the second dimension softmax1 :: forall n m w. KnownBits w => KnownNat n => KnownNat m => T '[m,n] ('Typ 'Float w) -> T '[m,n] ('Typ 'Float w) softmax1 = mapT softmax0 argmaxInternal :: forall n s0 s1 t u. KnownNat n => KnownNumeric t => KnownBits u => Sat KnownNat n -> SShape s0 -> SShape s1 -> T (s0 ++ (n ': s1)) t -> T (s0 ++ s1) ('Typ 'Int u) argmaxInternal _n s0 s1 = UnOp (Axis1Op s1 (ArgMax (natSat @n))) s0 axisSplitApp :: Axis n s -> (Take n s ++ Drop n s) :~: s axisSplitApp AxZero = Refl axisSplitApp (AxSucc n) = case axisSplitApp n of Refl -> Refl axisSplitApp' :: Axis n s -> (Take n s ++ (At n s ': Drop ('Succ n) s)) :~: s axisSplitApp' AxZero = Refl axisSplitApp' (AxSucc n) = case axisSplitApp' n of Refl -> Refl -- | Argmax along axis @n@ argmax :: forall m n u s t. (KnownShape s, KnownBits u, KnownNat m, KnownNumeric t) => Axis n s -> Tensor (Take n s ++ (m ': Drop n s)) t -> Tensor s ('Typ 'Int u) argmax n = case axisSplitApp n of Refl -> argmaxInternal (natSat @m) (sShapeTake' n (typeSShape @s)) (sShapeDrop' n s) where s = typeSShape @s -- | Argmax along the first dimension argmax0 :: forall u n s t. (KnownNat n, KnownShape s, KnownBits u, KnownNumeric t) => T (n ': s) t -> T s ('Typ 'Int u) argmax0 = argmaxInternal (natSat @n) Unit (typeSShape @s) -- | Argmax along the second dimension argmax1 :: forall u m n s t. (KnownNat n, KnownNat m, KnownShape s, KnownBits u, KnownNumeric t) => T (m ': n ': s) t -> T (m ': s) ('Typ 'Int u) argmax1 = argmaxInternal (natSat @n) (natSat @m :* Unit) (typeSShape @s) -- argmax1 = mapT argmax0 -- equivalent? -- | Cast the element type. cast :: forall u s t. KnownTyp t => KnownShape s => KnownTyp u => T s t -> T s u cast = appRUnit @s #> UnOp Cast (typeSShape @s) -- | (dense) softmax cross entropy with logits. softmaxCrossEntropyWithLogits :: forall numClasses. KnownNat numClasses => Tensor '[numClasses] Float32 -- ^ labels -> Tensor '[numClasses] Float32 -- ^ logits -> Tensor '[] Float32 softmaxCrossEntropyWithLogits = BinOp SoftmaxCrossEntropyWithLogits Unit (typeSShape @ '[numClasses]) typeSTyp (typeSShape @ '[numClasses]) typeSTyp -- | Computes sigmoid cross entropy given logits. Measures the -- probability error in discrete classification tasks in which each -- class is independent and not mutually exclusive. For instance, one -- could perform multilabel classification where a picture can contain -- both an elephant and a dog at the same time. See -- https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits sigmoidCrossEntropyWithLogits :: forall s w. KnownBits w => KnownShape s => Tensor s (Flt w) -- ^ labels -> Tensor s (Flt w) -- ^ logits -> Tensor s (Flt w) sigmoidCrossEntropyWithLogits = appRUnit @s #> BinOp SigmoidCrossEntropyWithLogits (typeSShape @s) Unit typeSTyp Unit typeSTyp -- | sparse softmax cross entropy with logits. sparseSoftmaxCrossEntropyWithLogits :: forall numClasses t. KnownNat numClasses => KnownBits t => Tensor '[] Int32 -- ^ desired label -> Tensor '[numClasses] (Flt t) -- ^ predictions for each label -> Tensor '[] (Flt t) sparseSoftmaxCrossEntropyWithLogits = BinOp SparseSoftmaxCrossEntropyWithLogits Unit Unit typeSTyp (typeSShape @ '[numClasses]) typeSTyp reverseT :: KnownTyp t => KnownNat n => T '[n] t -> T '[n] t reverseT = UnOp (Axis1Op Unit (ReverseT Sat)) Unit -- | One hot vector along axis 0 oneHot0 :: forall numClasses w s t. KnownNat numClasses => KnownNumeric t => KnownBits w => (KnownShape s) => Tensor s ('Typ 'Int w) -> Tensor (numClasses ': s) t oneHot0 = UnOp (Axis1Op (typeSShape @s) (OneHot Sat)) Unit -- | One hot vector along axis 1 oneHot1 :: forall numClasses w s m t. KnownBits w =>KnownShape s => KnownNat numClasses => KnownNat m => KnownNumeric t => Tensor (m ': s) ('Typ 'Int w) -> Tensor (m ': numClasses ': s) t oneHot1 = mapT oneHot0 -- | Generate a random tensor whose distribution is given. A new noise -- is sampled for each element in a batch. noise :: KnownShape s => Distribution s t -> Gen (T s t) noise d = do noiseId <- GPId -- necessary for correct broadcasting behaviour return $ Noise noiseId Unit typeSShape d -- | Clip a tensor clipByValue :: forall s t. KnownShape s => KnownBits t => Float -> Float -> T s (Flt t) -> T s (Flt t) clipByValue lo hi = appRUnit @s #> UnOp (Float1Op (ClipByValue lo hi)) (typeSShape @s) -- | (where_ c x y)[i] = if c[i] then x[i] else y[i] where_ :: T s TFBool -> T s t -> T s t -> T s t where_ = Where -- | Selection of a tensor (note: this is a strict operation) if_ :: forall s t. KnownShape s => Scalar TFBool -> T s t -> T s t -> T s t if_ = If -- FIXME: part of the workaround for https://github.com/tensorflow/tensorflow/issues/21901 -- if_ x = appRUnit @s $ where_ (broadcastTT @s x) -- | @(gather x ix)[k] = x[ix[k]]@. See https://www.tensorflow.org/api_docs/python/tf/gather gather :: forall n indexShape s t. KnownShape s => KnownNat n => KnownShape indexShape => T (n ': s) t -> T indexShape Int32 -> T (indexShape ++ s) t gather = Gather typeSShape Unit (natSat @n) typeSShape -- gather params ix = GatherND (typeSShape @'[n]) (typeSShape @s) (typeSShape @indexShape) params $ -- prodHomo @indexShape @'[1] $ -- (reshapeAuto ix) -- | @(lookup i xs) = xs[i]@. This function returns an element of a -- tensor at a dynamic index. This is a version of 'gather' -- specialised to a scalar index. lookupT :: KnownShape xs => KnownNat n => Scalar Int32 -> Tensor (n ': xs) t -> Tensor xs t lookupT ix xs = gather xs ix -- | x by y maxpool layer. maxPool2D :: forall windowx windowy height width channels t. KnownNat height => KnownNat width => KnownNat channels => (KnownNat windowx, KnownNat windowy, KnownBits t) => T '[windowx*width,windowy*height,channels] (Flt t) -> T '[width,height,channels] (Flt t) maxPool2D x = squeeze0 (Pool (natSat @1) (typeSShape @'[windowx,windowy]) MaxPool (natSat @channels) (typeSShape @'[width,height]) (expandDim0 x)) -- | maxpool layer. window size is the first type argument. maxPool1D :: forall window width channels t. KnownNat width => KnownNat channels => (KnownNat window,KnownBits t) => T '[window*width,channels] (Flt t) -> T '[width,channels] (Flt t) maxPool1D x = squeeze0 (Pool (natSat @1) (typeSShape @'[window]) MaxPool (natSat @channels) (typeSShape @'[width]) (expandDim0 x)) doExtractVars :: Gen a -> (a, GState, [VarInfo]) doExtractVars p = runRWS (extractVars p) () initialGstate extractVars :: Gen a -> RWS () [VarInfo] GState a extractVars (GPState f) = state f extractVars GPId = do GState {..} <- get put GState {nextVar=nextVar+1,..} return nextVar extractVars (GPVariable trainable name i) = do -- i <- mapM extractVars initial case i of Nothing -> return () Just i' -> when (not (null (freeVarsT i'))) $ error "aaaaaaaaarrrrghhh" GState {..} <- get let r = Ref name typeSShape typeSTyp tell [VarInfo trainable r i] return r extractVars (GPApp a b) = do f <- extractVars a; x <- extractVars b; return (f x) extractVars (GPBind a f) = do a' <- extractVars a extractVars (f a') extractVars (GPReturn x) = return x ================================================ FILE: TypedFlow/Broadcast.hs ================================================ {-# LANGUAGE InstanceSigs #-} {-| Module : TypedFlow.Abstract Description : Abstract Tensor representations Copyright : (c) Jean-Philippe Bernardy, 2018 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental This module provides operations on the abstract representation of tensor operations. It is not normally imported directly by users. -} {-# LANGUAGE ApplicativeDo #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UnicodeSyntax #-} {-# LANGUAGE CPP #-} #if __GLASGOW_HASKELL__ >= 806 {-# LANGUAGE NoStarIsType #-} #endif module TypedFlow.Broadcast ( -- * broadcasting doBroadcast,doBroadcastSingle,mapPlaceHolders, ConsSh, unopInputShape, -- * helpers which are also useful elsewhere -- ** reshapes reshape, reshapeAuto, reshapeFrom, reshapeTo, inflate2, flatten2, permToFun, -- ** transpositions transpose01, transposeN, transposeN', transposeN01, -- ** others concatT', range, ) where import Control.Monad.State -- import Data.Kind (Type,) import Data.Proxy import Data.Type.Equality import GHC.TypeLits import Prelude hiding (RealFrac(..)) import System.IO.Unsafe import TypedFlow.Memo2 hiding (Comp) import TypedFlow.Types (T(..), type (∘)(..)) import TypedFlow.Types hiding (T) import TypedFlow.Types.Proofs data GS = GS { gsUnique :: Unique } type G = StateT GS IO runG :: Unique -> G x -> x runG u m = fst (unsafePerformIO (runStateT m GS { gsUnique = u })) doBroadcastSingle :: forall s t. (KnownShape s, KnownTyp t) => T s t -> T s t doBroadcastSingle x = case doBroadcast @'[ '("_doBroadcastSingle" , s , t) ] (PHT x :* Unit) of PHT x' :* Unit -> x' doBroadcast :: All KnownPlaceholder ps => Placeholders ps -> Placeholders ps doBroadcast phs = runG 0 $ do F3m' bc <- mkBroadcastFn let broadcast :: forall n s t. BroadcastFn n s t broadcast = unwrapBCFn bc F2m' gBC' <- mkGenerateBC broadcast let generateBC :: forall s t. GenBCFn s t generateBC = unwrapGBCFn gBC' generateBCMany :: forall ps. All KnownPlaceholder ps => Placeholders ps -> G (Placeholders ps) generateBCMany = \case Unit -> return Unit (PHT x :* xs) -> do x' <- generateBC x xs' <- generateBCMany xs return (PHT x' :* xs') generateBCMany phs getUnique :: G Unique getUnique = do u <- gets ((1+) . gsUnique) modify $ \GS {} -> GS {gsUnique = u,..} return u generateBC' :: (forall n s t proxy. KnownTyp t => KnownShape s => KnownNat n => Unique -> Bool -> proxy n -> T s t -> G (T (n : s) t)) -> (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> G (T s' t')) -> forall s t. KnownTyp t => SShape s -> T s t -> G (T s t) generateBC' broadcast rec (n@Sat :* sR) (Zip3T _ _s1 _s2 _s3 f x y z) = knownSShape sR ?> do u <- getUnique -- ATTN: it is critical not to do recursive calls to x,y,z here. Doing so would create new nodes, loosing sharing, and creating problems down the line. a' <- rec sR (f (Unbroadcast n u x) (Unbroadcast n u y) (Unbroadcast n u z)) broadcast u False n a' generateBC' broadcast rec (n@Sat :* sR) (ZipT _ _s1 _s2 f x y) = knownSShape sR ?> do u <- getUnique a' <- rec sR (f (Unbroadcast n u x) (Unbroadcast n u y)) broadcast u False n a' generateBC' broadcast rec (n@Sat :* sR) (MapT _ _s' f x) = knownSShape sR ?> do u <- getUnique a' <- rec sR (f (Unbroadcast n u x)) broadcast u False n a' generateBC' broadcast rec (n@Sat :* sR) (BroadcastT maybeUnique varyNoise _ _s' a) = knownSShape sR ?> do u <- case maybeUnique of Nothing -> getUnique Just u' -> return u' a' <- rec sR a broadcast u varyNoise n a' generateBC' _ _ _ (n@T {}) = return n generateBC' _ _ _ (n@Noise {}) = return n generateBC' _ rec _ (BinOp op s0 s1 t1 s2 t2 x y) = knownTyp t1 $ knownTyp t2 $ BinOp op s0 s1 t1 s2 t2 <$> (rec (s0 .+. s1) x) <*> (rec (s0 .+. s2) y) generateBC' _ rec _ (UnOp op s0 x) = UnOp op s0 <$> rec (s0 .+. unopInputShape op) x generateBC' _ rec sR (Unbroadcast p u' x) = Unbroadcast p u' <$> rec (p :* sR) x generateBC' _ rec _ (DirectBroadcast s0 s1 s2 s3 x) = DirectBroadcast s0 s1 s2 s3 <$> (rec (s0 .+. s2) x) generateBC' _ rec _ (ReshapeFrom s0 x) = reshapeFrom s0 <$> rec s0 x generateBC' _ rec _ (Transpose s0 t x) = Transpose s0 t <$> (rec s0 x) generateBC' _ rec _ (Concat s0 s1 xs) = Concat s0 s1 <$> hTraverse (\(Catable m x) -> Catable m <$> (rec (s0 .+. m :* s1) x)) xs generateBC' _ rec _ (Gather is s0 m s1 x ix) = Gather is s0 m s1 <$> (rec (s0 .+. m :* s1) x) <*> rec (s0 .+. is) ix generateBC' _ rec _ (GatherND cs es is x ix) = GatherND cs es is <$> (rec (cs .+. es) x) <*> (rec (is *: sListLenAsNat cs) ix) generateBC' _ rec _ (MatMul s0 a b c x y) = MatMul s0 a b c <$> (rec (s0 .+. a :* b :* Unit) x) <*> (rec (s0 .+. b :* c :* Unit) y) generateBC' _ rec sR (Where cond x y) = Where <$> rec sR cond <*> rec sR x <*> rec sR y generateBC' _ rec sR (If cond x y) = If <$> rec Unit cond <*> rec sR x <*> rec sR y generateBC' _ rec _ (Convolution bs@Sat inChans outChans filterShape s0 x filters) = Convolution bs inChans outChans filterShape s0 <$> (rec (bs :* (s0 *: inChans)) x) <*> (rec (filterShape .+. inChans :* outChans :* Unit) filters) generateBC' _ rec _ (Pool bs@Sat window pt numChans outSpatial x) = Pool bs window pt numChans outSpatial <$> rec (bs :* (zipWithMulSShapes window outSpatial *: numChans)) x generateBC' _ rec _ (Softmax bs n x) = Softmax bs n <$> (rec (bs :* n :* Unit) x) generateBC' _ _ _ _ = error "generateBC': unhandled case" (<&&>) :: Applicative f => f Bool -> f Bool -> f Bool x <&&> y = (&&) <$> x <*> y -- | True if the argument does not contain an expression which should be broadcast. protoFinished :: Unique -> Bool -> (forall s' t'. Unique -> Bool -> T s' t' -> G Bool) -> T s t -> G Bool protoFinished u varyNoise rec0 = let rec :: forall s t. T s t -> G Bool rec = rec0 u varyNoise in \case BroadcastT _ _ _ _s a -> rec a MapT _ s f x -> rec x <&&> rec (f (T (Variable (Ref 0 s typeSTyp)))) ZipT _ s0 s1 f x y -> rec x <&&> rec y <&&> rec (f (T (Variable (Ref 0 s0 typeSTyp))) (T (Variable (Ref 0 s1 typeSTyp)))) Zip3T _ s0 s1 s2 f x y z -> rec x <&&> rec y <&&> rec z <&&> rec (f (T (Variable (Ref 0 s0 typeSTyp))) (T (Variable (Ref 0 s1 typeSTyp))) (T (Variable (Ref 0 s2 typeSTyp)))) Softmax _ _ x -> rec x DirectBroadcast _ _ _ _ x -> rec x GatherND _ _ _ x y -> rec x <&&> rec y Noise _ _ _ _ -> return (not varyNoise) Where cond x y -> rec cond <&&> rec x <&&> rec y If cond x y -> rec cond <&&> rec x <&&> rec y T _ -> return True Unbroadcast _p u' _x -> return (u /= u') UnOp _op _ x -> rec x MatMul _ _ _ _ x y -> rec x <&&> rec y BinOp _op _ _ _ _ _ x y -> rec x <&&> rec y Gather _is _s0 _m _s1 x ix -> rec x <&&> rec ix Transpose _ _t x -> rec x ReshapeFrom _s x -> rec x Concat _s0 _s1 xs -> (and . htoList) <$> hTraverse (\(Catable _ x) -> K <$> rec x) xs Convolution _bs _inChans _outChans _filterShape _s x filters -> rec x <&&> rec filters Pool _ _ _ _ _ x -> rec x -- _ -> error "protoFinished: unhandled case" data K02 t x y = K02 {fromK02 :: t} mkFinished :: G (F2m G (Sig02 Bool (Sig02 Unique T)) (K02 Bool) ) -- forall s' t'. Unique -> Bool -> T s' t' -> G (F2m _) mkFinished = memo2 (ordMap @Bool `containing02` (ordMap @Unique `containing02` snMap2 @T)) $ \rec (Ex02 u (Ex02 v x)) -> K02 <$> protoFinished v u (unwrapFin rec) x unwrapFin :: ((Sig02 Bool (Sig02 Unique T)) s t -> G (K02 Bool s t)) -> Unique -> Bool -> T s t -> G Bool unwrapFin f u v x = fromK02 <$> f (Ex02 v (Ex02 u x)) data KT s t where KT :: STyp t -> SShape s -> KT s t type GenBCFn s t = (KnownTyp t, KnownShape s) => T s t -> G (T s t) unwrapGBCFn :: forall s t. (T s t -> KT s t -> G (T s t)) -> GenBCFn s t unwrapGBCFn f x' = f x' (KT typeSTyp typeSShape) -- isBroadcastT :: T s t -> Bool -- isBroadcastT (BroadcastT {}) = True -- isBroadcastT _ = False mkGenerateBC :: (forall n s t. BroadcastFn n s t) -> G (F2m' G T KT T) mkGenerateBC broadcast = memo2' (snMap2 @T) $ \rec x (KT t s) -> knownTyp t $ do r <- generateBC' broadcast (\sh' x' -> rec x' (KT typeSTyp sh')) s x -- when (isBroadcastT r) $ liftIO $ putStrLn "YIKES!" return r newtype BC'd (n :: Nat) (s :: Shape) (t :: Typ) = BC'd {fromBC'd :: (T (n : s) t)} data KTn n s t where KTn :: STyp t -> SShape s -> KTn n s t type BroadcastFn n s t = forall proxy. (KnownNat n, KnownShape s, KnownTyp t) => Unique -> Bool -> proxy n -> T s t -> G (T (n : s) t) unwrapBCFn :: ((Sig03 Unique (Sig03 Bool (Sig12 (Sat KnownNat) T))) n s t -> KTn n s t -> G (BC'd n s t)) -> BroadcastFn n s t unwrapBCFn f u v _n x' = fromBC'd <$> f (Ex03 u (Ex03 v (Ex12 natSat x'))) (KTn typeSTyp typeSShape) mkBroadcastFn :: G (F3m' G (Sig03 Unique (Sig03 Bool (Sig12 (Sat KnownNat) T))) KTn BC'd) mkBroadcastFn = do F2m fin <- mkFinished memo3' (ordMap @Unique `containing03` (ordMap @Bool `containing03` (verifMap1 @(Sat KnownNat) `containing12` snMap2 @T))) $ \rec (Ex03 u (Ex03 v (Ex12 n x))) (KTn st sh) -> BC'd <$> protoBroadcast u v n (\sh' x' -> fromBC'd <$> rec (Ex03 u (Ex03 v (Ex12 n x'))) (KTn typeSTyp sh')) (unwrapFin fin u v) st sh x class ConsSh (x :: Nat) (p :: (Symbol,Shape,Typ)) instance Fun (ConsSh x) where type Ap (ConsSh x) p = '(Frst3 p,x ': Scnd3 p,Thrd3 p) -- -- | Turns a tensor of indices in a container into a tensor of indices -- -- in a container of higher rank. The added indexed dimension -- -- corresponds to the first dimension of the index. -- broadcastIndex :: forall n containerShape indexShape w. -- KnownBits w => Sat KnownNat n -> -- SShape containerShape -> -- SShape indexShape -> -- IndexTensor (n ': indexShape) containerShape w -> -- IndexTensor (n ': indexShape) (n ': containerShape) w -- broadcastIndex n cs = broadcastIndex' n (sListLenAsNat cs) broadcastIndex' :: forall n containerRank indexShape w. KnownBits w => Sat KnownNat n -> Sat KnownNat containerRank -> SShape indexShape -> T (n ': indexShape ++ '[containerRank]) ('Typ 'Int w) -> T (n ': indexShape ++ '[1 + containerRank]) ('Typ 'Int w) broadcastIndex' n@(Comp Dict) cr is ix = concatT' ((:*) n is) (natSat @1) cr Unit nIndex ix where nIndex :: T (n ': indexShape ++ '[1]) ('Typ 'Int w) nIndex = DirectBroadcast Unit Unit ((:*) n Unit) (is .+. (:*) (natSat @1) Unit) range -- directBroadcast0 :: forall n s t. KnownShape s => KnownNat n => T s t -> T (n:s) t -- directBroadcast0 = appRUnit @s #> DirectBroadcast Unit ((:*) (natSat @n) Unit) (typeSShape @s) Unit -- broadcastIndexMany :: forall n containerShape indexShape w. -- KnownBits w => -- Sat KnownNat n -> -- SShape containerShape -> -- SShape indexShape -> -- IndexTensor indexShape '[n] w -> -- IndexTensor (containerShape ++ indexShape) (containerShape ++ '[n]) w -- broadcastIndexMany _ Unit _ x = x -- broadcastIndexMany n ((:*) m@Sat cs) is x = -- knownSShape (cs .+. (*:) is (sListLenAsNat (cs *: n))) ?> -- -- (m : cs ++ is ++ '[(Length (m : cs ++ [n]))]) -- (broadcastIndex m ((*:) cs n) (cs .+. is) $ -- -- (m : (cs ++ is ++ '[Length (cs ++ [n])])) -- (appAssocS cs is ((:*) (sListLenAsNat (cs *: n)) Unit) #> -- -- (m : cs ++ is ++ '[Length (cs ++ [n])]) -- directBroadcast0 $ -- -- (cs ++ is ++ '[Length (cs ++ [n])]) -- broadcastIndexMany n cs is x)) -- -- is -- Product (filterSpatialShape ++ '[inChannels, outChannels * n]) -- Product ((filterSpatialShape ++ '[inChannels, outChannels]) ++ '[n]) axisOpInputShape :: Axis1Op s1 t s2 u -> SShape s1 axisOpInputShape o = case o of ArgMax n -> HSingle n OneHot _n -> Unit ReduceOp n _ -> HSingle n ReverseT n -> HSingle n SliceOp _ n _ _ -> HSingle n AccessOp n _ -> HSingle n unopInputShape :: UnOp s t s' t' -> SShape s unopInputShape (Diag n) = n :* Unit unopInputShape Cast = Unit unopInputShape (Axis1Op s o) = axisOpInputShape o .+. s unopInputShape StopGradient = Unit unopInputShape (Num1Op _) = Unit unopInputShape (Float1Op _) = Unit unopInputShape (ExpM n) = n :* n :* Unit unopInputShape (ZeroTriangle n _ _) = n :* n :* Unit unopInputShape Conjugate = Unit unopInputShape RealPart = Unit protoBroadcast :: forall n s t. Unique -- unique identifier marking the variable tensor which will be marking inputs (not to broadcast). -> Bool -- how to expand the noise? (If True use different noise for all indices) -> Sat KnownNat n -- added dimension's size -> (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> G (T (n ': s') t')) -- recursive case -> (forall s' t'. T s' t' -> G Bool) -- test if we're done -> STyp t -- representation of the type -> SShape s -- representation of the shape -> T s t -- tensor (expression) to broadcast -> G (T (n ': s) t) -- return broadcated expression (on 1st position) protoBroadcast u varyNoise n@(Comp Dict) rec finished ty s tensor = do isFinished <- finished tensor case isFinished of True -> simpleBC False -> knownTyp ty $ case tensor of BroadcastT {} -> error "BroadcastT case remaining, this should have been dealt with by generateBC" MapT {} -> error "MapT case remaining, this should have been dealt with by generateBC" ZipT {} -> error "ZipT case remaining, this should have been dealt with by generateBC" Zip3T {} -> error "Zip3T case remaining, this should have been dealt with by generateBC" Softmax bs@Sat m@Sat x -> prodAssocS n bs m #> do x' <- rec (typeSShape) x return (reshapeAuto (Softmax (satMul n bs) m (reshapeAuto x'))) DirectBroadcast s0 s1 s2 s3 x -> do x' <- (rec (s0 .+. s2) x) return (DirectBroadcast (n :* s0) s1 s2 s3 x') GatherND cs es is x ix -> do xFinished <- finished x case xFinished of True -> GatherND cs es (n :* is) x <$> (rec (is *: sListLenAsNat cs) ix) False -> do ix' <- rec (is *: sListLenAsNat cs) ix x' <- (rec (cs .+. es) x) return (GatherND (n :* cs) es (n :* is) x' (broadcastIndex' n (sListLenAsNat cs) is ix')) Noise v s0 s1 x -> if varyNoise then return (Noise v (n :* s0) s1 x) else simpleBC -- When varying noise, then we extend the shape of the noise (so -- more stuff is sampled), otherwise we copy the noise using simple -- broadcasting Pool bs@Sat window pt numChans outSpatial x -> (knownSShape (zipWithMulSShapes window outSpatial *: numChans) ?> (prodAssocS n bs (productS (zipWithMulSShapes window outSpatial *: numChans)) #> (prodAssocS n bs (productS (outSpatial *: numChans)) #> do x' <- (rec typeSShape x) return $ (reshapeFrom (satMul n bs :* outSpatial *: numChans) $ Pool (satMul n bs) window pt numChans outSpatial (reshapeAuto x'))))) Where cond x y -> Where <$> (rec s cond) <*> (rec s x) <*> (rec s y) If cond x y -> do condFinished <- finished cond case condFinished of True -> If cond <$> (rec s x) <*> (rec s y) False -> error "broadcast on 'if' condition not implemented" T _ -> error "panic: broadcast constant should be finished!" Unbroadcast p@Sat u' x | u == u' -> return $ case testEq p n of Nothing -> UnOp (error "panic.unbroadcast.unit") Unit x Just Refl -> x | otherwise -> knownSShape s ?> do x' <- (rec (p :* s) x) return (Unbroadcast p u' (transpose01 x')) -- An uncomplete broadcast (in another dimension). MatMul s0 a@Sat b@Sat c@Sat x y -> do yFinished <- finished y case (s0,yFinished) of (Unit,True) -> do -- this optimisation is absolutely critical to implement dense -- layers efficiently (at least with TF 1.3). (about 10x performance increase) x' <- (rec (a :* b :* Unit) x) return $ inflate2 (MatMul Unit (satMul n a) b c (flatten2 x') y) _ -> MatMul (n :* s0) a b c <$> (rec (s0 .+. a :* b :* Unit) x) <*> (rec (s0 .+. b :* c :* Unit) y) BinOp op s0 s1 t1 s2 t2 x y -> knownTyp t1 $ knownTyp t2 $ do BinOp op (n :* s0) s1 t1 s2 t2 <$> (rec (s0 .+. s1) x) <*> (rec (s0 .+. s2) y) UnOp op s0 x -> UnOp op (n :* s0) <$> (rec (s0 .+. unopInputShape op) x) Gather is s0 m s1 x ix -> do xFinished <- finished x case (s0,xFinished) of -- this optimisation is important to get efficient embeddings (???) (Unit,True) -> Gather (n :* is) Unit m s1 x <$> (rec is ix) _ -> Gather is (n :* s0) m s1 <$> (rec (s0 .+. m :* s1) x) <*> (rec (s0 .+. is) ix) Transpose s0 t x -> Transpose (n :* s0) (PermSkip t) <$> (rec s0 x) ReshapeFrom s0 x -> reshapeFrom (n :* s0) <$> (rec s0 x) Concat s0 s1 xs -> do Concat (n :* s0) s1 <$> hTraverse (\(Catable m x) -> Catable m <$> (rec (s0 .+. m :* s1) x)) xs Convolution bs@(Sat) inChans outChans filterShape s0 x filters -> do filtersFinished <- finished filters xFinished <- finished x case (filtersFinished,xFinished) of (True,_) -> prodAssocS n bs (productS (s0 *: inChans)) #> prodAssocS n bs (productS (s0 *: outChans)) #> knownSShape (s0 *: inChans) ?> do x' <- (rec typeSShape x) return $ reshapeFrom (satMul n bs :* s0 *: outChans) (Convolution (satMul n bs) inChans outChans filterShape s0 (reshapeAuto x') filters) (_,True) -> knownSShape (filterShape .+. inChans :* outChans :* Unit) ?> knownSShape (bs :* s0 .+. outChans :* Unit) ?> do filters' <- rec typeSShape filters return $ transposeN' $ reshapeProven (ANat bs !:* AShape s0 *:! (ANat outChans :*: ANat n)) ((ANat bs !:* AShape s0 *:! ANat outChans) *:! ANat n) $ Convolution bs inChans (outChans `satMul` n) filterShape s0 x $ reshapeProven ((AShape filterShape :++: (ANat inChans !:* Single (ANat outChans))) *:! ANat n) (AShape filterShape :++: ANat inChans !:* Single (ANat outChans :*: ANat n)) $ transposeN $ filters' _ -> error "broadcast on both convolution filter and data not implemented" _ -> error "protoBroadcast: unhandled case" where simpleBC :: G (T (n ': s) t) simpleBC = appRUnit @s #> return (DirectBroadcast Unit (n :* Unit) s Unit tensor) inversePerm :: Permutation a b -> Permutation b a inversePerm PermId = PermId inversePerm (PermSkip x) = PermSkip (inversePerm x) inversePerm PermSwap = PermSwap inversePerm (PermTrans x y) = PermTrans (inversePerm y) (inversePerm x) permToFun :: Permutation s t -> Integer -> Integer permToFun = \case PermId -> \x -> x PermTrans a b -> permToFun b . permToFun a PermSwap -> \case 0 -> 1 1 -> 0 x -> x PermSkip p -> \case 0 -> 0 x -> permToFun p (x-1) + 1 reshapeAuto :: forall s s0 t. KnownShape s0 => Product s ~ Product s0 => T s0 t -> T s t reshapeAuto = reshapeFrom typeSShape reshapeProven :: forall s s0 t n. ShapeX s0 n -> ShapeX s n -> T s0 t -> T s t reshapeProven s1 s2 = case decideProductEq s1 s2 of Refl -> knownSShape (exprSShape s1) ?> reshapeAuto reshapeTo :: forall s s0 t proxy. KnownShape s0=> Product s ~ Product s0 => proxy s -> T s0 t -> T s t reshapeTo _ = reshapeAuto reshapeFrom :: forall s s0 t. Product s ~ Product s0 => SShape s0 -> T s0 t -> T s t reshapeFrom _ (ReshapeFrom s1 x) = ReshapeFrom s1 x -- avoid reshaping over and over reshapeFrom s0 x = ReshapeFrom s0 x type BatchedPlaceholders n ps = Placeholders (BPH n ps) type BPH n ps = (Ap (FMap (ConsSh n)) ps) -- | Batch the model (adding one dimension). mapPlaceHolders :: forall batchSize shapesAndTypes resShapesAndTypes. (KnownNat batchSize, KnownLen shapesAndTypes, KnownLen resShapesAndTypes, All KnownPlaceholder shapesAndTypes, All KnownPlaceholder resShapesAndTypes) => Unique -> Bool -> (Placeholders shapesAndTypes -> Placeholders resShapesAndTypes) -> BatchedPlaceholders batchSize shapesAndTypes -> (BatchedPlaceholders batchSize resShapesAndTypes) mapPlaceHolders u varyNoise f xs = broadcastPlacehoders @batchSize typeSList (f (unbroadcastPlacehoders @batchSize typeSList xs)) where unbroadcastPlacehoders :: forall n r. KnownNat n => SList r -> BatchedPlaceholders n r -> Placeholders r unbroadcastPlacehoders Unit Unit = Unit unbroadcastPlacehoders (_ :* ss) (PHT x :* xs') = PHT (Unbroadcast batchSize u x) :* unbroadcastPlacehoders @n ss xs' where batchSize = natSat @n broadcastPlacehoders :: forall n r. All KnownPlaceholder r => KnownNat n => SList r -> Placeholders r -> (BatchedPlaceholders n r) broadcastPlacehoders Unit Unit = Unit broadcastPlacehoders (_ :* ss) (PHT x :* xs) = let x' = BroadcastT (Just u) varyNoise (natSat @n) typeSShape x xs' = broadcastPlacehoders @n ss xs in (PHT x' :* xs') ---------------------------------------------------------------- -- Here start helper functions permN :: SList s -> Permutation (n ': s) (s ++ '[n]) permN Unit = PermId permN ((:*) _n s) = PermSwap `PermTrans` PermSkip (permN s) permN01 :: SList s -> Proxy m -> Proxy n -> Permutation (s ++ [m,n]) (s ++ [n,m]) permN01 Unit _ _ = PermSwap permN01 ((:*) _n s) m n = PermSkip (permN01 s m n) -- | Transposition. See the type for the permutation of dimensions. transposeN :: ∀ s n t. KnownNat n => KnownShape s => T (n ': s) t -> T (s ++ '[n]) t transposeN = doTranspose typeSShape (permN (typeSList @s)) -- | Transposition. See the type for the permutation of dimensions. transposeN' :: ∀ s n t. KnownNat n => KnownShape s => T (s ++ '[n]) t -> T (n ': s) t transposeN' = doTranspose (typeSShape @s *: (natSat @n)) (inversePerm (permN (typeSList @s))) -- | Transposition. See the type for the permutation of dimensions. transposeN01 :: ∀ s m n t. KnownNat n => KnownNat m => KnownShape s => T (s ++ [m,n]) t -> T (s ++ [n,m]) t transposeN01 = doTranspose (typeSShape @s .+. typeSShape @'[m,n]) (permN01 (typeSList @s) (Proxy @m) (Proxy @n)) -- | Transposition. See the type for the permutation of dimensions. transpose01 :: ∀ s m n t. KnownNat n => KnownNat m => KnownShape s => T (m ': n ': s) t -> T (n ': m ': s) t transpose01 = doTranspose typeSShape PermSwap doTranspose :: SShape s0 -> Permutation s0 s -> T s0 t -> T s t doTranspose _ p (Transpose sh' q x) = doTranspose sh' (PermTrans q p) x doTranspose sh p x = Transpose sh p x -- | Concatenate tensors with explicit shapes. concatT' :: ∀ s0 d1 d2 s1 t. KnownTyp t => SShape s0 -> Sat KnownNat d1 -> Sat KnownNat d2 -> SShape s1 -> T (s0 ++ (d1 ': s1)) t -> T (s0 ++ (d2 ': s1)) t -> T (s0 ++ ((d1+d2) ': s1)) t concatT' s0 d1@(Comp Dict) d2@(Comp Dict) s1 x y = Concat s0 s1 (Catable d1 x :* Catable d2 y :* Unit) -- | Reshape a tensor so that the first dimension is expanded into two. inflate2 :: ∀ m n s t. KnownTyp t => (KnownNat m, KnownNat n, KnownShape s) => Tensor (m*n ': s) t -> Tensor (m ': n ': s) t inflate2 = prodAssoc @m @n @(Product s) #> reshape -- | Reshape a tensor so that the first two dimensions are collapsed flatten2 :: ∀ m n s t. KnownTyp t => (KnownNat m, KnownNat n, KnownShape s) => Tensor (m ': n ': s) t -> Tensor (m*n ': s) t flatten2 = prodAssoc @m @n @(Product s) #> reshape reshape :: ∀ s2 s1 t. KnownShape s1 => KnownShape s2 => Product s1 ~ Product s2 => Tensor s1 t -> Tensor s2 t reshape = reshapeAuto -- | range[i] = i range :: forall n w. KnownNat n => KnownBits w => T '[n] ('Typ 'Int w) range = T (Range (natSat @n)) ================================================ FILE: TypedFlow/Haskell.hs ================================================ {-| Module : TypedFlow.Haskell Description : Generation of computation graph using tensorflow haskell. Copyright : (c) Jean-Philippe Bernardy, 2017 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental -} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UndecidableSuperClasses #-} {-# LANGUAGE UnicodeSyntax #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module TypedFlow.Haskell where import Data.Type.Equality import Data.List (genericReplicate) import GHC.TypeLits import Control.Monad.State import TypedFlow.Types import TypedFlow.Types.Proofs import TypedFlow.Abstract (newId, permToFun, unopInputShape) import TypedFlow.Memo import System.Mem.StableName import System.IO.Unsafe import qualified Data.Int as Backend import qualified TensorFlow.Core as Backend import qualified TensorFlow.GenOps.Core as BackCore import qualified TensorFlow.Minimize as Backend import qualified TensorFlow.Ops as Backend import qualified TensorFlow.NN as Backend -- import qualified TensorFlow.Variable as Backend import qualified TensorFlow.Tensor import qualified Data.IntMap as IM import Data.IntMap (IntMap) type BackendShape = BackendTensor ('Typ 'Int 'B32) type BackendTensor t = Backend.Tensor Backend.Build (HaskType t) type BackendVariable t = Backend.Tensor Backend.Ref (HaskType t) type BackendTensorType t = Backend.TensorType (HaskType t) shapeFromType :: ∀ (s :: Shape). KnownShape s => BackendShape shapeFromType = shapeVector (typeSShape @s) -- | Show a shape, but "None" is replaced by "-1" shapeVector :: forall (s::Shape) proxy. All KnownNat s => SList' proxy s -> BackendShape shapeVector s = shapeFromList (shapeToList'' s) permToTensor :: SShape s -> Permutation s t -> Backend.Tensor Backend.Build Backend.Int32 permToTensor s p = Backend.vector (map (fromInteger . permToFun p) [0.. sListLength s]) shapeFromList :: [Integer] -> BackendShape shapeFromList = Backend.vector . map convertNone showShapeLen :: ∀ (s::Shape). KnownLen s => Backend.Int32 showShapeLen = fromIntegral (listTypeLen @ s) convertNone :: Num a => Integer -> a convertNone n = (if n == 514229 then (-1) else fromIntegral n) -- runWithFeeds data BT (s :: Shape) (t :: Typ) where BT :: forall s t. (BackendTensor t) -> BT s t data HState = HState {genVars :: IntMap Var ,genPureTable :: SNMap22 Shape Typ T BT -- alternative: use tensorRefFromName and make this closer to the python backed. } type BM a = Backend.BuildT (StateT HState (State GState)) a data Var = forall s t v. TensorFlow.Tensor.TensorKind v => Var (SShape s) (STyp t) (Backend.Tensor v (HaskType t)) initializedVariable :: forall s a. KnownShape s => KnownTyp a => T s a -> BM (Ref s a) initializedVariable initVal = do BT i <- interpretPure initVal x <- lift (lift newId) v <- backendTensor (typeSTyp @a) $ Backend.initializedVariable i let var = (Var (typeSShape @s) (typeSTyp @a) v) lift (modify $ \HState{..} -> HState {genVars = IM.insert (fromIntegral x) var genVars,..}) return (Ref (fromIntegral x) typeSShape typeSTyp ) placeholder :: forall s a. SShape s -> STyp a -> BM (Ref s a) placeholder s t = do x <- lift (lift newId) ph <- backendTensor t $ Backend.placeholder (Backend.Shape (map convertNone $ shapeToList' s)) let var = (Var s t ph) lift (modify $ \HState{..} -> HState {genVars = IM.insert (fromIntegral x) var genVars,..}) return (Ref (fromIntegral x) s t ) interpGen :: Gen a -> BM a interpGen (GPReturn x) = return x interpGen (GPVariable _trainable _name initVal) = initializedVariable initVal interpGen (GPPlaceholder s t _name) = placeholder s t interpGen (GPModify _ _) = error "GPModify: TODO" interpGen (GPState f) = lift (lift (state f)) interpGen (GPBind a b) = do x <- interpGen a interpGen (b x) listProxyLen :: forall proxy s. KnownLen s => proxy s -> Integer listProxyLen _ = listTypeLen @s -- genDistr :: forall s s0 t. KnownTyp t => Distribution s t -> SShape s0 -> SShape s -> DOC -- genDistr d sh s1 = case d of -- TruncatedNormalD stddev -> funcall "tf.truncated_normal" -- [showSShape (sh .+. s1), named "stddev" (float stddev), named "dtype" (showTyp @t)] -- UniformD low high -> funcall "tf.random_uniform" [showSShape (sh .+. s1) -- ,named "minval" (float low) -- ,named "maxval" (float high) -- ,named "dtype" (showTyp @t)] -- OrthogonalD -> -- funcall' (funcall "tf.orthogonal_initializer" [named "dtype" (showTyp @t)]) [named "shape" (showSShape (sh .+. s1))] knownNumeric :: forall t k. KnownNumeric t => (KnownTyp t => Num (HaskType t) => Backend.OneOf '[Backend.Int32, Float, Double] (HaskType t) => k) -> k knownNumeric = knownNumeric' (typeSTyp @t) knownNumeric' :: forall t k. KnownNumeric t => STyp t -> (KnownTyp t => Num (HaskType t) => Backend.OneOf '[Backend.Int32, Float, Double] (HaskType t) => k) -> k knownNumeric' (STyp tk tb Refl) k = case tk of SFloat -> case tb of SB32 -> k SB64 -> k SBool -> error "TFNumeric bug" SInt -> case tb of SB32 -> k SB64 -> error "missing in tensorflow: int64 is not supported in matmul T_T" knownFloatingB :: forall t k. (KnownTyp t, TypKind t ~ 'Float) => (Backend.OneOf '[Float, Double] (HaskType t) => k) -> k knownFloatingB k = case bitsVal @(TypBits t) of SB32 -> k SB64 -> k knownInt :: forall t k. (KnownTyp t, TypKind t ~ 'Int) => (Backend.OneOf '[Backend.Int32, Backend.Int64] (HaskType t) => k) -> k knownInt k = case bitsVal @(TypBits t) of SB32 -> k SB64 -> k backendTensor :: STyp t -> (Backend.TensorType (HaskType t) => k) -> k backendTensor (STyp SFloat SB32 Refl) k = k backendTensor (STyp SInt SB64 Refl) k = k backendTensor (STyp SBool _ Refl) k = k backendTensor (STyp SFloat SB64 Refl) k = k backendTensor (STyp SInt SB32 Refl) k = k backendTensor' :: forall t k proxy. KnownTyp t => proxy t -> (Backend.TensorType (HaskType t) => k) -> k backendTensor' _ = backendTensor (typeSTyp @t) runUnOp :: forall s s1 t s2 u. KnownTyp u => KnownTyp t => BackendTensorType u => SShape s -> UnOp s1 t s2 u -> BT (s++s1) t -> BT (s++s2) u runUnOp sL op (BT x) = backendTensor (typeSTyp @t) $ case op of SliceOp _ sR lo hi -> BT $ BackCore.slice x (shapeFromList (replicate (sListLen sL) 0 ++ [lo] ++ replicate (sListLen sR) 0)) (shapeFromList (shapeToList' sL ++ [hi-lo] ++ (shapeToList' sR))) Axis1Op aop -> case aop of (ArgMax _ _) -> knownNumeric @t $ knownInt @u $ BT $ BackCore.argMax x (Backend.scalar sLLen) (OneHot _) -> knownNumeric @u $ knownInt @t $ BT $ Backend.oneHot x (Backend.scalar sLLen) (Backend.scalar 1) (Backend.scalar 0) ReduceOp _ _sR rop -> knownNumeric @t $ case rop of Max -> BT $ BackCore.max x redindices Min -> BT $ BackCore.min x redindices Sum -> BT $ Backend.sum x redindices Mean -> BT $ Backend.mean x redindices where redindices = (Backend.vector [fromIntegral (sListLen sL) :: Backend.Int32 ]) StopGradient -> BT $ BackCore.stopGradient x Cast -> BT $ Backend.cast x (Num1Op numop) -> knownNumeric @t $ case numop of Square -> BT (Backend.mul x x) Negate -> BT (Backend.neg x) Sign -> BT (Backend.sign x) Abs -> BT (Backend.abs x) FloorMod -> BT (Backend.floorMod x) Float1Op flop -> knownFloatingB @t $ knownFloating @(TypBits u) $ knownFloatingB @u $ case flop of Tanh -> BT (BackCore.tanh x) Sin -> BT (BackCore.sin x) Exp -> BT (BackCore.exp x) Sigmoid -> BT (BackCore.sigmoid x) Relu -> BT (BackCore.relu x) Floor -> BT (BackCore.floor x) Round -> BT (BackCore.round x) Cos -> BT (BackCore.cos x) Log -> BT (BackCore.log x) Asin -> BT (BackCore.asin x) Acos -> BT (BackCore.acos x) Sinh -> BT (BackCore.sinh x) Cosh -> BT (BackCore.cosh x) Asinh -> BT (BackCore.asinh x) Acosh -> BT (BackCore.acosh x) Atan -> BT (BackCore.atan x) Atanh -> BT (BackCore.atanh x) Sqrt -> BT (BackCore.sqrt x) HardSigmoid -> error "Haskell: no hard sigmoid defined yet" ClipByValue lo hi -> BT $ BackCore.clipByValue x (Backend.scalar $ realToFrac lo) (Backend.scalar $ realToFrac hi) Diag _ -> BT $ BackCore.batchMatrixDiag x where sLLen = fromIntegral (sListLen sL) :: Backend.Int32 interpretPure :: forall s t. KnownTyp t => KnownShape s => T s t -> BM (BT s t) interpretPure x = do let sn = unsafePerformIO $ makeStableName x mv <- snMap22Lookup sn <$> lift (gets genPureTable) case mv of Just v -> return v Nothing -> do e <- interpretPure' (\s x' -> knownSShape s $ interpretPure x') typeSShape x lift $ modify (\g -> g {genPureTable = (snMap22Insert (KV sn e)) (genPureTable g)}) return e interpNilOp :: forall s t. Backend.TensorType (HaskType t) => NilOp s t -> BM (BT s t) interpNilOp = \case Constant c -> return $ BT $ Backend.scalar c Range n@Sat -> knownNumeric @t $ return $ let start,limit,delta :: HaskType t start = 0 limit = fromIntegral $ natVal n delta = 1 in BT $ Backend.range (Backend.scalar start) (Backend.scalar limit) (Backend.scalar delta) Variable (Ref r sr tr) -> do tbl <- lift (gets genVars) case IM.lookup r tbl of Just (Var sx tx x) -> case (testEq sx sr, testEq tx tr) of (Just Refl, Just Refl) -> return (BT (Backend.expr x)) _ -> error "panic: variable does not have the expected type" _ -> error "panic: variable not found" interpretPure' :: forall s t. KnownTyp t => (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> BM (BT s' t')) -> SShape s -> T s t -> BM (BT s t) interpretPure' rec sR = knownSShape sR $ backendTensor (typeSTyp @t) $ \case Unbroadcast{} -> error "broadcasting operation did not complete!" DirectBroadcast s0 s1 s2 s3 x -> do BT recx <- rec (s0 .+. s2) x let expandedShape = shapeFromList (concat [shapeToList' s0, genericReplicate (sListLength s1) 1 ,shapeToList' s2, genericReplicate (sListLength s3) 1 ]) targetShape = shapeFromList sR return $ BT $ BackCore.broadcastTo (Backend.reshape recx expandedShape) targetShape -- Noise noiseId s0 s1 x -> do -- return $ (genDistr x s0 s1) <+> (text "# " <> integer noiseId) T op -> interpNilOp op Where c x y -> do BT rc <- rec typeSShape c BT rx <- rec typeSShape x BT ry <- rec typeSShape y return $ BT $ BackCore.select rc rx ry UnOp operation s0 x -> do recx <- rec (s0 .+. unopInputShape operation) x return (runUnOp s0 operation recx) MatMul s0 a b c x y -> do BT recx <- rec (s0 .+. a :* b :* Unit) x BT recy <- rec (s0 .+. b :* c :* Unit) y return $ knownNumeric @t $ BT $ BackCore.batchMatMul recx recy BinOp operation s0 s1 t s2 u x y -> knownSShape s0 $ knownSShape s1 $ knownSShape s2 $ knownProduct' s0 $ do BT recx <- rec (s0 .+. s1) x BT recy <- rec (s0 .+. s2) y let reshx = backendTensor t $ Backend.reshape recx (shapeVector (satProd s0 :* s1)) reshy = backendTensor u $ Backend.reshape recy (shapeVector (satProd s0 :* s2)) return $ case operation of Simple2Op sop -> case sop of Add -> knownNumeric @t $ BT $ Backend.add recx recy Divide -> knownNumeric @t $ BT $ BackCore.div recx recy Equal -> backendTensor u $ BT $ Backend.equal recx recy Subtract -> knownNumeric @t $ BT $ Backend.sub recx recy Multiply -> knownNumeric @t $ BT $ Backend.mul recx recy Minimum -> knownNumeric @t $ BT $ BackCore.minimum recx recy Maximum -> knownNumeric @t $ BT $ BackCore.maximum recx recy LessThan -> knownNumeric' u $ BT $ BackCore.less recx recy -- WTF moment: the arguments do not seem to be in the same order in python as in haskell -- python: https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits -- haskell: https://tensorflow.github.io/haskell/haddock/tensorflow-core-ops-0.2.0.0/TensorFlow-GenOps-Core.html#v:sparseSoftmaxCrossEntropyWithLogits SparseSoftmaxCrossEntropyWithLogits -> case t of STyp SInt SB32 Refl -> knownFloatingB @t $ BT $ fst $ BackCore.sparseSoftmaxCrossEntropyWithLogits reshy reshx SoftmaxCrossEntropyWithLogits -> knownFloatingB @t $ BT $ fst $ BackCore.softmaxCrossEntropyWithLogits reshy reshx -- SigmoidCrossEntropyWithLogits -> knownFloatingB @t $ BT $ Backend.sigmoidCrossEntropyWithLogits recy recx -- type is not as general as necessary ReshapeFrom s t -> do BT rt <- rec s t return $ BT $ BackCore.reshape rt (shapeVector sR) Concat s0 s1 xs -> do let go :: forall s0 s1 ns. SShape s0 -> SShape s1 -> NP (Catable s0 s1 t) ns -> BM [BackendTensor t] go _ _ Unit = return [] go s0' s1' (Catable n y :* ys) = do BT y' <- rec (s0' .+. n :* s1') y (y' :) <$> go s0' s1' ys rxs <- go s0 s1 xs return $ BT $ Backend.concat (Backend.scalar (fromIntegral (sListLength s0))) rxs Transpose s p x -> do BT rx <- rec s x return $ BT $ Backend.transpose rx (permToTensor s p) -- Gather indexShape s0 m s1 x ix -> do -- rx <- rec (s0 .+. ((:*) m s1)) x -- rix <- rec indexShape ix -- return (func "tf.gather" [rx, rix] []) -- GatherND containerShape elementShape indexShape x ix -> do -- rx <- rec (containerShape .+. elementShape) x -- rix <- rec (indexShape *: (sListLenAsNat containerShape)) ix -- return (func "tf.gather_nd" [rx, rix] []) Convolution bs inChans outChans filterShape s0 x filters -> do BT recx <- rec (bs :* (s0 *: inChans)) x BT recFilters <- rec (filterShape .+. inChans :* outChans :* Unit) filters case filterShape of _width :* _height :* Unit -> return $ BT $ knownFloatingB @t $ BackCore.conv2D recx recFilters _ -> error "TypedFlow Haskell backend: convolution on an unsupported number of dims" -- Pool bs window typ numChans outSpatial x -> do -- rx <- rec ((:*) bs (zipWithMulSShapes window outSpatial .+. (:*) numChans Unit)) x -- return (func "tf.nn.pool" -- [rx, showSShape window, typ', text (show ("SAME" :: String))] -- [("strides", showSShape window)]) -- where typ' = text $ (show $ case typ of MaxPool -> "MAX"; AvgPool -> "AVG" :: String) -- -- where rec :: forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> DOC -- -- rec = generatePure' ================================================ FILE: TypedFlow/Layers/Core.hs ================================================ {-| Module : TypedFlow.Layers.Core Description : Core layers and combinators. Copyright : (c) Jean-Philippe Bernardy, 2017 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental -} {-# LANGUAGE CPP #-} #if __GLASGOW_HASKELL__ >= 806 {-# LANGUAGE NoStarIsType #-} #endif {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE TypeInType #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UnicodeSyntax #-} {-# LANGUAGE PatternSynonyms #-} module TypedFlow.Layers.Core ( -- * Dense DenseP(..), dense, (#), -- * Dropout DropProb(..), mkMask, mkDropout, mkDropouts, -- * Embedding EmbeddingP(..), embedding, -- * Convolutional ConvP(..), conv, conv', {-convValid,-} maxPool1D, maxPool2D, glu ) where import Prelude hiding (RealFrac(..)) import GHC.TypeLits import TypedFlow.TF import TypedFlow.Types import TypedFlow.Types.Proofs import TypedFlow.Abstract import Control.Monad.State (gets) import Data.Monoid ((<>)) --------------------- -- Linear functions -- | A dense layer is a linear function form a to b: a transformation matrix and a bias. data DenseP t a b = DenseP {denseWeights :: Tensor '[a,b] t ,denseBiases :: Tensor '[b] t} ----------------------- -- Feed-forward layers -- | Parameters for the embedding layers newtype EmbeddingP numObjects embeddingSize t = EmbeddingP (Tensor '[numObjects, embeddingSize] t) instance (KnownNat numObjects, KnownTyp b, KnownNat embeddingSize) => KnownTensors (EmbeddingP numObjects embeddingSize b) where travTensor f s (EmbeddingP p) = EmbeddingP <$> travTensor f s p instance (KnownNat numObjects, KnownBits b, KnownNat embeddingSize) => ParamWithDefault (EmbeddingP numObjects embeddingSize ('Typ 'Float b)) where defaultInitializer = EmbeddingP <$> (noise $ UniformD (-0.05) 0.05) instance (KnownNat numObjects, KnownBits b, KnownNat embeddingSize) => ParamWithDefault (EmbeddingP numObjects embeddingSize ('Typ 'Cmplx b)) where defaultInitializer = EmbeddingP <$> (mkComplex <$> (noise $ UniformD (-0.05) 0.05) <*> (noise $ UniformD (-0.05) 0.05)) -- | embedding layer embedding :: ∀ embeddingSize numObjects t. KnownNat embeddingSize => KnownNat numObjects => EmbeddingP numObjects embeddingSize t -> Tensor '[] Int32 -> Tensor '[embeddingSize] t embedding (EmbeddingP param) input = gather param input instance (KnownNat a, KnownNat b, KnownTyp t) => KnownTensors (DenseP t a b) where travTensor f s (DenseP x y) = DenseP <$> travTensor f (s<>"_w") x <*> travTensor f (s<>"_bias") y instance (KnownNat n, KnownNat m, KnownFloat b) => ParamWithDefault (DenseP b n m) where defaultInitializer = DenseP <$> glorotUniform <*> (noise $ TruncatedNormalD 0.1) -- | Dense layer (Apply a linear function) (#), dense :: ∀m n t. KnownNat n => KnownNat m => KnownNumeric t => DenseP t n m -> Tensor '[n] t -> Tensor '[m] t (DenseP weightMatrix bias) # v = (weightMatrix ∙ v) + bias dense = (#) -- | A drop probability. (This type is used to make sure one does not -- confuse keep probability and drop probability) data DropProb = DropProb Float -- | Generate a dropout function. The mask applied by the returned -- function will be constant for any given call to mkDropout. See -- 'noise' for the sampling behaviour. mkDropout :: forall s t. KnownShape s => KnownFloat t => DropProb -> Gen (Tensor s t -> Tensor s t) mkDropout d = (⊙) <$> mkMask d -- | Generate a 0-1 mask with given probability, suitable for dropout, -- or all ones if not in training phase. See 'noise' for the sampling -- behaviour. mkMask :: forall s t. KnownShape s => KnownFloat t => DropProb -> Gen (Tensor s t) mkMask (DropProb dropProb) = do let keepProb = 1 - dropProb let isTraining = genTrainingPlaceholder r <- noise $ UniformD keepProb (1 + keepProb) return $ if_ isTraining (floor r ⊘ constant (knownAlgebraic @t $ realToFrac keepProb)) ones newtype EndoTensor t s = EndoTensor (Tensor s t -> Tensor s t) -- | Generate a dropout function for an heterogeneous tensor vector. mkDropouts :: KnownFloat t => KnownLen shapes => All KnownShape shapes => DropProb -> Gen (HTV t shapes -> HTV t shapes) mkDropouts d = appEndoTensor <$> mkDropouts' typeSList where mkDropouts' :: forall shapes t. KnownFloat t => All KnownShape shapes => SList shapes -> Gen (NP (EndoTensor t) shapes) mkDropouts' Unit = return Unit mkDropouts' (_ :* rest) = do x <- mkDropout d xs <- mkDropouts' rest return (EndoTensor x :* xs) appEndoTensor :: NP (EndoTensor t) s -> HTV t s -> HTV t s appEndoTensor Unit Unit = Unit appEndoTensor (EndoTensor f :* fs) (F x :* xs) = F (f x) :* appEndoTensor fs xs ------------------------ -- Convolutional layers data ConvP t outChannels inChannels filterSpatialShape = ConvP (T (filterSpatialShape ++ '[inChannels,outChannels]) t) (T '[outChannels] t) instance (KnownNat outChannels,KnownNat inChannels, KnownShape filterSpatialShape, KnownFloat t) => ParamWithDefault (ConvP t outChannels inChannels filterSpatialShape) where defaultInitializer = prodHomo @filterSpatialShape @'[inChannels, outChannels] #> prodAssoc @(Product filterSpatialShape) @inChannels @outChannels #> knownAppend @filterSpatialShape @'[inChannels,outChannels] ?> knownProduct @filterSpatialShape ?> ConvP <$> (reshape <$> i) <*> pure (knownAlgebraic @t (constant 0.1)) where i :: Gen (T '[Product filterSpatialShape*inChannels,outChannels] t) i = knownProduct @filterSpatialShape ?> glorotUniform instance (KnownNat outChannels,KnownNat inChannels, KnownShape filterSpatialShape, KnownAlgebraic t) => KnownTensors (ConvP t outChannels inChannels filterSpatialShape) where travTensor f s (ConvP x y) = knownAppend @filterSpatialShape @'[inChannels,outChannels] ?> (ConvP <$> travTensor f (s<>"_filters") x <*> travTensor f (s <> "_biases") y) -- | Size-preserving convolution layer conv' :: forall s outChannels filterSpatialShape inChannels t. KnownShape s => KnownNat inChannels => KnownNat outChannels => KnownShape filterSpatialShape => KnownAlgebraic t => Length filterSpatialShape <= 3 => Length filterSpatialShape ~ Length s => ConvP t outChannels inChannels filterSpatialShape -> T (s ++ '[inChannels]) t -> T (s ++ '[outChannels]) t conv' (ConvP filters bias) input = mapTT @s (+bias) (convolution @outChannels @filterSpatialShape @inChannels @s input filters) conv :: forall outChannels filterSpatialShape inChannels s t. KnownShape s => KnownNat inChannels => KnownNat outChannels => KnownShape filterSpatialShape => KnownAlgebraic t => Length filterSpatialShape <= 3 => (Length filterSpatialShape + 1) ~ Length s -- The ranks must match, but not necessarily the dimensions => (Last s ~ outChannels) => ConvP t outChannels inChannels filterSpatialShape -> T (Init s ++ '[inChannels]) t -> T s t conv = initLast' @s #> incrPos @(Length filterSpatialShape) #> lengthInit (typeSList @s) #> incrCong @(Length filterSpatialShape) @(Length (Init s)) #> knownInit @s ?> conv' @(Init s) -- -- | Convolution layers with no padding (applying the filter only on -- -- positions where the input is fully defined, aka "VALID" in -- -- tensorflow.) -- convValid :: forall outChannels filterSpatialShape inChannels s t. -- ((1 + Length filterSpatialShape) ~ Length s, -- Length filterSpatialShape <= 3, -- KnownLen filterSpatialShape) -- the last dim of s is the batch size -- => ConvP t outChannels inChannels filterSpatialShape -- ^ Parameters -- -> T ('[inChannels] ++ AddSpatialDims s filterSpatialShape) ('Typ 'Float t) -- ^ input -- -> (T ('[outChannels] ++ s) ('Typ 'Float t)) -- convValid (ConvP filters bias) input = convolutionValid input filters + bias -- | Gated Linear Unit -- See: Language Modeling with Gated Convolutional Networks -- https://arxiv.org/pdf/1612.08083.pdf glu :: forall n t. KnownBits t => KnownNat n => T '[n+n] ('Typ 'Float t) -> T '[n] ('Typ 'Float t) glu x = plusMono @n @n #> knownPlus @n @n ?> let gate, h :: T '[n] ('Typ 'Float t) gate = slice0 @0 @n x h = termCancelation @n @n #> slice0 @n @(n+n) x in sigmoid gate ⊙ h ================================================ FILE: TypedFlow/Layers/RNN/Attention.hs ================================================ {-| Module : TypedFlow.Layers.RNN.Attention Description : Attention combinators to be used with RNN cells Copyright : (c) Jean-Philippe Bernardy, 2018 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental -} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE TypeInType #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UnicodeSyntax #-} {-# LANGUAGE PatternSynonyms #-} module TypedFlow.Layers.RNN.Attention ( -- * Attention mechanisms -- ** Scoring functions AttentionScoring, multiplicativeScoring, AdditiveScoringP(..), additiveScoring, -- ** Attention functions AttentionFunction, uniformAttn, luongAttention, -- ** Attention combinators attentiveWithFeedback ) where import Prelude hiding (RealFrac(..)) import GHC.TypeLits import TypedFlow.TF import TypedFlow.Types import TypedFlow.Types.Proofs (appRUnit,(#>)) import TypedFlow.Layers.RNN.Base -- | An attention scoring function. This function should produce a -- score (between 0 and 1). type AttentionScoring t keySize valueSize = Tensor '[keySize] t -> Tensor '[valueSize] t -> Tensor '[] t -- | A function which attends to an external input. Typically a -- function of this type is a closure which has the attended input in -- its environment. This environment is interpreted as an associative -- memory form key to value. type AttentionFunction t keySize valueSize = T '[keySize] t -> T '[valueSize] t -- | @attnExample1 θ h st@ combines each element of the vector h with -- s, and applies a dense layer with parameters θ. The "winning" -- element of h (using softmax) is returned. uniformAttn :: ∀ valueSize m keySize t. KnownNat valueSize => KnownNat m => KnownFloat t => AttentionScoring t keySize valueSize -- ^ scoring function -> T '[] Int32 -- ^ length of the input -> T '[m,valueSize] t -- ^ input (what we're attending to) -> AttentionFunction t keySize valueSize uniformAttn score len hs key = c where xx,α :: T '[m] t xx = mapT (score key) hs α = softmax0 (mask ⊙ xx) c :: T '[valueSize] t c = hs ∙ α mask = cast (sequenceMask @m len) -- mask according to length -- | Add some attention to an RnnCell, and feed the attention vector to -- the next iteration in the rnn. (This follows the diagram at -- https://github.com/tensorflow/nmt#background-on-the-attention-mechanism -- commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a). attentiveWithFeedback ::forall attSize cellSize inputSize w ss. KnownNat inputSize => KnownNat attSize => KnownLen ss => KnownTyp w => AttentionFunction w cellSize attSize -> RnnCell w ss (T '[inputSize+attSize] w) (T '[cellSize] w) -> RnnCell w ('[attSize] ': ss) (T '[inputSize ] w) (T '[attSize] w) attentiveWithFeedback attn cell = appRUnit @ss #> withFeedback (cell .-. timeDistribute attn) -- -- | LSTM for an attention model. The result of attention is fed to the next step. -- attentiveLstm :: forall attSize n x bs t. KnownNat bs => -- AttentionFunction t bs n attSize -> -- LSTMP t n (x+attSize) -> -- RnnCell t '[ '[attSize,bs], '[n,bs], '[n,bs] ] (Tensor '[x,bs] (Flt t)) (Tensor '[attSize,bs] (Flt t)) -- attentiveLstm att w = attentiveWithFeedback att (lstm w) -- | Luong attention function (following -- https://github.com/tensorflow/nmt#background-on-the-attention-mechanism -- commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a). -- Essentially a dense layer with tanh activation, on top of uniform attention. luongAttention :: ∀ attnSize d m e w. KnownNat e => KnownNat d => KnownNat attnSize => KnownNat m => KnownFloat w => Tensor '[d+e,attnSize] w -- ^ weights for the dense layer -> AttentionScoring w e d -- ^ scoring function -> Tensor '[] Int32 -- ^ length of the input -> T '[m,d] w -- ^ inputs -> AttentionFunction w e attnSize luongAttention w scoring lens hs_ ht = let ct = uniformAttn scoring lens hs_ ht in (tanh (w ∙ (concat0 ct ht))) -- | Multiplicative scoring function multiplicativeScoring :: forall valueSize keySize t. KnownFloat t => KnownNat valueSize => KnownNat keySize => T [keySize,valueSize] t -- ^ weights -> AttentionScoring t keySize valueSize multiplicativeScoring w dt h = ir · h where ir :: T '[valueSize] t ir = w ∙ dt data AdditiveScoringP sz keySize valueSize t = AdditiveScoringP (Tensor '[1,sz] t) (Tensor '[keySize, sz] t) (Tensor '[valueSize, sz] t) instance (KnownNat n, KnownNat k, KnownNat v, KnownTyp t) => KnownTensors (AdditiveScoringP k v n t) where travTensor f s (AdditiveScoringP x y z) = AdditiveScoringP <$> travTensor f (s<>"_v") x <*> travTensor f (s<>"_w1") y <*> travTensor f (s<>"_w2") z instance (KnownNat n, KnownNat k, KnownNat v, KnownFloat t) => ParamWithDefault (AdditiveScoringP k v n t) where defaultInitializer = AdditiveScoringP <$> glorotUniform <*> glorotUniform <*> glorotUniform -- | An additive scoring function. See https://arxiv.org/pdf/1412.7449.pdf additiveScoring :: forall sz keySize valueSize t. KnownNat valueSize => KnownNat sz => KnownNat keySize => KnownFloat t => AdditiveScoringP sz keySize valueSize t -> AttentionScoring t valueSize keySize additiveScoring (AdditiveScoringP v w1 w2) dt h = r'' where w1h :: Tensor '[sz] t w1h = w1 ∙ h w2dt = w2 ∙ dt z' :: Tensor '[sz] t z' = tanh (w1h + w2dt) r'' = z' · squeeze0 v ================================================ FILE: TypedFlow/Layers/RNN/Base.hs ================================================ {-| Module : TypedFlow.Layers.RNN.Base Description : RNN cells, layers and combinators. Copyright : (c) Jean-Philippe Bernardy, 2017 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental -} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE TypeInType #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UnicodeSyntax #-} {-# LANGUAGE PatternSynonyms #-} module TypedFlow.Layers.RNN.Base ( -- * Cell Combinators RnnCell, simpleRnn, runCell, mkCell, stackRnnCells, (.-.), bothRnnCells, (.|.), withBypass, withFeedback, onStates, -- * Rnn Combinators Rnn, runRnn, stackRnns, (.--.), bothRnns,(.++.), -- * RNN unfolding functions timeDistribute, iterateCell, iterateCellBackward, iterateWithCull, -- * Monad-like interface for cell construction Component(..), bindC, returnC, -- rnnBackwardsWithCull, ) where import Prelude hiding (tanh,Num(..),Floating(..),floor) import GHC.TypeLits import TypedFlow.TF import TypedFlow.Types import TypedFlow.Types.Proofs -- import Data.Type.Equality -- import Data.Kind (Type,Constraint) -- | The RNN Component generalized monad. This can be used to build -- RNNs cells which do not follow the simple and usual "stacking" -- pattern. This is not a simple monad, because the indexing over -- states is non-uniform; see 'BindC'. newtype Component t (states::[Shape]) a = C {runC :: HTV t states -> (HTV t states , a)} -- Note: states are tensors only, because we need to index into them -- in the time dimension in iterateWithCull instance Functor (Component t states) where fmap = mapC mapC :: (a -> b) -> Component t s a -> Component t s b mapC f c = C $ \s -> let (s',x) = runC c s in (s', f x) -- | Unit of the Component monad. returnC :: a -> Component t '[] a returnC x = C $ \Unit -> (Unit,x) -- | Bind operation for Components. States are accumulated. bindC :: forall t s0 s1 a b. KnownLen s1 => Component t s0 a -> (a -> Component t s1 b) -> Component t (s1++s0) b bindC f g = C $ \(hsplit @s1 -> (s1,s0)) -> let (s0',x) = runC f s0 (s1',y) = runC (g x) s1 in (happ s1' s0',y) -- | A cell (one time-step) in an rnn. @state@ is the state propagated through time. type RnnCell t states input output = input -> Component t states output -- | An rnn. @n@ is the length of the time sequence. @state@ is the state propagated through time. type Rnn n b state input output = RnnCell b state (V n input) (V n output) -- | Run a cell runCell :: RnnCell t states input output -> (HTV t states,input) -> (HTV t states, output) runCell cell = uncurry (flip (runC . cell)) -- | Run an RNN, using a tensor as input. @n@ is the length of the time sequence. runRnn :: (KnownNat n,KnownShape s0, KnownShape s1, KnownTyp t1) => Rnn n t2 states (T s1 t1) (T s0 t0) -> (HTV t2 states, Tensor (n ': s1) t1) -> (HTV t2 states, Tensor (n ': s0) t0) runRnn l (s,x) = let x' = unstack0 x (s',y) = runCell l (s,x') in (s',stack0 y) -- | Run an RNN composed of a single RNN cell. simpleRnn :: KnownTyp t1 => KnownShape s1 => KnownShape s0 => KnownNat n => RnnCell t2 states (T s1 t1) (T s0 t0) -> (HTV t2 states, Tensor (n : s1) t1) -> (HTV t2 states, Tensor (n : s0) t0) simpleRnn = runRnn . iterateCell -- | Construct a cell from an arbitrary stateful function mkCell :: ((HTV t states,input) -> (HTV t states, output)) -> RnnCell t states input output mkCell cell = C . flip (curry cell) ---------------------- -- Lifting functions -- | Convert a pure function (feed-forward layer) to an RNN cell by -- ignoring the RNN state. timeDistribute :: (a -> b) -> RnnCell t '[] a b timeDistribute = constantOverSteps -- | Convert a pure function (feed-forward layer) to an RNN cell by -- ignoring the RNN state. constantOverSteps :: (a -> b) -> RnnCell t '[] a b constantOverSteps stateLess a = returnC (stateLess a) -------------------------------------- -- Combinators -- | Compose two rnn layers. This is useful for example to combine -- forward and backward layers. (.--.),stackRnns :: forall s1 s2 a b c n bits. KnownLen s2 => Rnn n bits s1 a b -> Rnn n bits s2 b c -> Rnn n bits (s2 ++ s1) a c stackRnns = stackRnnCells infixr .--. (.--.) = stackRnns -- | Compose two rnn layers in parallel. bothRnns,(.++.) :: forall s1 s2 a b c n bits t. KnownTyp t => KnownLen s1 => KnownLen s2 => KnownNat n => KnownNat b => KnownNat c => Rnn n bits s1 a (T '[b] t) -> Rnn n bits s2 a (T '[c] t) -> Rnn n bits (s2 ++ s1) a (T ('[b+c]) t) bothRnns f g x = f x `bindC` \y -> g x `bindC` \z -> returnC (concat0 <$> y <*> z) infixr .++. (.++.) = bothRnns -- | Apply a function on the cell state(s) before running the cell itself. onStates :: (HTV t xs -> HTV t xs) -> RnnCell t xs a b -> RnnCell t xs a b onStates f cell x = C $ \h -> do runC (cell x) (f h) -- | Stack two RNN cells (LHS is run first) stackRnnCells, (.-.) :: forall s0 s1 a b c t. KnownLen s1 => RnnCell t s0 a b -> RnnCell t s1 b c -> RnnCell t (s1 ++ s0) a c stackRnnCells l1 l2 x = l1 x `bindC` l2 (.-.) = stackRnnCells -- | Compose two rnn cells in parallel. bothRnnCells, (.|.) :: forall s0 s1 a b c t bits. KnownLen s0 => KnownLen s1 => KnownBits bits => KnownNat b => KnownNat c => RnnCell t s0 a (T '[b] (Flt bits)) -> RnnCell t s1 a (T '[c] (Flt bits)) -> RnnCell t (s1 ++ s0) a (T '[b+c] (Flt bits)) bothRnnCells l1 l2 x = l1 x `bindC` \y -> l2 x `bindC` \z -> returnC (concat0 y z) (.|.) = bothRnnCells -- | Run the cell, and forward the input to the output, by -- concatenation with the output of the cell. This bypass is sometimes -- called a 'highway' in the literature. withBypass :: forall x y t b s0. KnownNat x => KnownNat y => KnownLen s0 => KnownTyp t => RnnCell b s0 (T '[x] t) (T '[y] t) -> RnnCell b s0 (T '[x] t) (T '[x+y] t) withBypass cell x = appRUnit @s0 #> cell x `bindC` \y -> returnC (concat0 x y) -- | Run the cell, and feeds its output as input to the next time-step withFeedback :: forall outputSize inputSize (w :: Typ) ss. KnownTyp w => KnownNat outputSize => KnownNat inputSize => RnnCell w ss (T '[inputSize+outputSize] w) (T '[outputSize] w) -> RnnCell w ('[outputSize] ': ss) (T '[inputSize ] w) (T '[outputSize] w) withFeedback cell x = C $ \(F prevoutputnVector :* s) -> let (s',y) = runC (cell (concat0 x prevoutputnVector)) s in (F y :* s',y) --------------------------------------------------------- -- RNN unfolding -- | Build a RNN by repeating a cell @n@ times. iterateCell :: ∀ n state input output b. (KnownNat n) => RnnCell b state input output -> Rnn n b state input output iterateCell c x = C $ \s -> chainForward (\(t,y) -> runC (c y) t) (s,x) -- | Build a RNN by repeating a cell @n@ times. However the state is -- propagated in the right-to-left direction (decreasing indices in -- the time dimension of the input and output tensors) iterateCellBackward :: ∀ n state input output b. (KnownNat n) => RnnCell b state input output -> Rnn n b state input output iterateCellBackward c x = C $ \s -> chainBackward (\(t,y) -> runC (c y) t) (s,x) -- | RNN helper chainForward :: ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (state , V n b) chainForward _ (s0 , VUnit) = (s0 , VUnit) chainForward f (s0 , x :** xs) = let (s1,x') = f (s0 , x) (sFin,xs') = chainForward f (s1 , xs) in (sFin,(x':**xs')) -- | RNN helper chainBackward :: ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (state , V n b) chainBackward _ (s0 , VUnit) = (s0 , VUnit) chainBackward f (s0 , (x:**xs)) = let (s1,xs') = chainBackward f (s0,xs) (sFin, x') = f (s1,x) in (sFin,(x':**xs')) -- | RNN helper chainForwardWithState :: ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (V n b, V n state) chainForwardWithState _ (_s0 , VUnit) = (VUnit, VUnit) chainForwardWithState f (s0 , (x:**xs)) = let (s1,x') = f (s0 , x) (xs',ss) = chainForwardWithState f (s1 , xs) in ((x':**xs'), (s1:**ss) ) -- -- | RNN helper -- chainBackwardWithState :: -- ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (state , V n b, V n state) -- chainBackwardWithState _ (s0 , VUnit) = return (s0 , VUnit, VUnit) -- chainBackwardWithState f (s0 , (x:**xs)) = do -- (s1,xs',ss') <- chainBackwardWithState f (s0,xs) -- (sFin, x') <- f (s1,x) -- return (sFin,(x':**xs'),(sFin:**ss')) -- | RNN helper transposeV :: forall n xs t. All KnownShape xs => KnownNat n => SList xs -> V n (HTV t xs) -> HTV t (Ap (FMap (Cons n)) xs) transposeV Unit _ = Unit transposeV (_ :* n) xxs = F ys' :* yys' where (ys,yys) = help @(Tail xs) xxs ys' = stack0 ys yys' = transposeV n yys help :: forall ys x tt. V n (HTV tt (x ': ys)) -> (V n (T x tt) , V n (HTV tt ys)) help (xs) = ((fmap (fromF . hhead) xs),(fmap htail xs)) -- | @(gatherFinalStates dynLen states)[i] = states[dynLen[i]-1]@ gatherFinalStates :: KnownShape x => KnownNat n => T '[] Int32 -> T (n ': x) t -> T x t gatherFinalStates dynLen states = gather states (dynLen ⊝ constant 1) gathers :: forall n xs t. All KnownShape xs => KnownNat n => SList xs -> T '[] Int32 -> HTV t (Ap (FMap (Cons n)) xs) -> HTV t xs gathers Unit _ Unit = Unit gathers (_ :* n) ixs (F x :* xs) = F (gatherFinalStates ixs x) :* gathers @n n ixs xs -- | @rnnWithCull dynLen@ constructs an RNN as normal, but returns the -- state after step @dynLen@ only. iterateWithCull :: forall n x y ls b. KnownLen ls => KnownNat n => All KnownShape ls => T '[] Int32 -- ^ dynamic length -> RnnCell b ls x y -> Rnn n b ls x y iterateWithCull dynLen cell xs = C $ \s0 -> let (us,ss) = chainForwardWithState (uncurry (flip (runC . cell))) (s0,xs) sss = transposeV @n (typeSList @ls) ss in (gathers @n (typeSList @ls) dynLen sss,us) -- -- | Like @rnnWithCull@, but states are threaded backwards. -- rnnBackwardsWithCull :: forall n bs x y ls b. -- KnownLen ls => KnownNat n => All KnownLen ls => All (LastEqual bs) ls => -- T '[bs] Int32 -> RnnCell b ls x y -> RNN n b ls x y -- rnnBackwardsWithCull dynLen cell (s0, t) = do -- (us,ss) <- chainBackwardWithState cell (s0,xs) -- let sss = transposeV @n (shapeSList @ls) ss -- return (gathers @n (shapeSList @ls) (n - dynLen) sss,us) ================================================ FILE: TypedFlow/Layers/RNN/Cells.hs ================================================ {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnicodeSyntax #-} {-| Module : TypedFlow.Layers.RNN.Cells Description : RNN cells Copyright : (c) Jean-Philippe Bernardy, 2017 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental -} module TypedFlow.Layers.RNN.Cells ( -- * RNN Cells cellInitializerBit, LSTMP(..), lstm, GRUP(..), gru, StackP(..), stackRU, ) where import TypedFlow.Layers.RNN.Base import TypedFlow.TF import TypedFlow.Types import TypedFlow.Types.Proofs import GHC.TypeLits import TypedFlow.Layers.Core (DenseP(..),(#)) import Prelude hiding (RealFrac(..)) -------------------------------------- -- Cells -- | Standard RNN gate initializer. (The recurrent kernel is -- orthogonal to avoid divergence; the input kernel is glorot) cellInitializerBit :: ∀ n x t. (KnownNat n, KnownNat x, KnownFloat t) => Gen (DenseP t (n + x) n) cellInitializerBit = DenseP <$> (concat0 <$> recurrentInitializer <*> kernelInitializer) <*> biasInitializer where recurrentInitializer :: Gen (Tensor '[n, n] t) recurrentInitializer = noise $ OrthogonalD kernelInitializer :: Gen (Tensor '[x, n] t) kernelInitializer = glorotUniform biasInitializer = pure zeros -- | Parameter for an LSTM data LSTMP t n x = LSTMP (DenseP t (n+x) n) (DenseP t (n+x) n) (DenseP t (n+x) n) (DenseP t (n+x) n) instance (KnownNat n, KnownNat x, KnownFloat t) => KnownTensors (LSTMP t n x) where travTensor f s (LSTMP x y z w) = LSTMP <$> travTensor f (s<>"_f") x <*> travTensor f (s<>"_i") y <*> travTensor f (s<>"_c") z <*> travTensor f (s<>"_o") w instance (KnownNat n, KnownNat x, KnownFloat t) => ParamWithDefault (LSTMP t n x) where defaultInitializer = LSTMP <$> forgetInit <*> cellInitializerBit <*> cellInitializerBit <*> cellInitializerBit where forgetInit = DenseP <$> (denseWeights <$> cellInitializerBit) <*> pure ones -- | Standard LSTM lstm :: ∀ n x t. KnownNat x => KnownNat n => KnownFloat t => LSTMP t n x -> RnnCell t '[ '[n], '[n]] (Tensor '[x] t) (Tensor '[n] t) lstm (LSTMP wf wi wc wo) input = C $ \(VecPair ht1 ct1) -> let f = sigmoid (wf # hx) hx = (concat0 ht1 input) i = sigmoid (wi # hx) cTilda = tanh (wc # hx) o = sigmoid (wo # hx) c = ((f ⊙ ct1) + (i ⊙ cTilda)) h = (o ⊙ tanh c) in (VecPair h c, h) -- | Parameter for a GRU data GRUP t n x = GRUP (T [n+x,n] t) (T [n+x,n] t) (T [n+x,n] t) instance (KnownNat n, KnownNat x, KnownFloat t) => KnownTensors (GRUP t n x) where travTensor f s (GRUP x y z) = GRUP <$> travTensor f (s<>"_z") x <*> travTensor f (s<>"_r") y <*> travTensor f (s<>"_w") z instance (KnownNat n, KnownNat x, KnownFloat t) => ParamWithDefault (GRUP t n x) where defaultInitializer = GRUP <$> (denseWeights <$> cellInitializerBit) <*> (denseWeights <$> cellInitializerBit) <*> (denseWeights <$> cellInitializerBit) -- | Standard GRU cell gru :: ∀ n x t. KnownNat x => (KnownNat n, KnownFloat t) => GRUP t n x -> RnnCell t '[ '[n] ] (Tensor '[x] t) (Tensor '[n] t) gru (GRUP wz wr w) xt = C $ \(VecSing ht1) -> let hx = (concat0 ht1 xt) zt = sigmoid (wz ∙ hx) rt = sigmoid (wr ∙ hx) hTilda = tanh (w ∙ (concat0 (rt ⊙ ht1) xt)) ht = ((ones ⊝ zt) ⊙ ht1 + zt ⊙ hTilda) in (VecSing ht, ht) data StackP w n = StackP (DenseP w (n + n) 3) defStackP :: KnownNat n => KnownFloat w => Gen (StackP w n) defStackP = StackP <$> defaultInitializer -- (DenseP glorotUniform (stack0 (V [zeros, constant (-1), zeros]) )) -- demote popping a bit instance (KnownNat n, KnownTyp w) => KnownTensors (StackP w n) where travTensor f s (StackP d) = StackP <$> travTensor f s d instance (KnownNat n, KnownFloat w) => (ParamWithDefault (StackP w n)) where defaultInitializer = defStackP -- | A stack recurrent unit. The input has two purposes: 1. it is -- saved in a stack. 2. it controls (a dense layer which gives) the -- operation to apply on the stack. The first type argument is the -- depth of the stack. stackRU :: ∀k n bs w. KnownNat k => KnownNat n => (KnownNat bs) => (KnownFloat w) => StackP w n -> RnnCell w '[ '[k+1,n]] (Tensor '[n] w) (Tensor '[n] w) stackRU (StackP w) input = C $ \(VecSing st1) -> succPos @k #> plusMono @k @1 #> plusComm @k @1 #> termCancelation @k @1 #> let ct1 = nth0' @0 st1 hx = concat0 ct1 input action :: T '[3] w action = softmax0 (w # hx) tl :: T '[k,n] w tl = slice0 @1 @(k+1) st1 it :: T '[k,n] w it = slice0 @0 @k st1 stTilda :: T '[3,k+1,n] w stTilda = stack0 (st1 :** (tl `concat0` zeros) :** (expandDim0 input `concat0` it) :** VUnit) st :: T '[k+1,n] w st = inflate2 (flatten12 stTilda ∙ action) ct = nth0' @0 st in (VecSing st, ct) ================================================ FILE: TypedFlow/Layers/RNN.hs ================================================ {-| Module : TypedFlow.Layers.RNN Description : RNN cells, layers and combinators. Copyright : (c) Jean-Philippe Bernardy, 2017 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental -} module TypedFlow.Layers.RNN ( module TypedFlow.Layers.RNN.Base, module TypedFlow.Layers.RNN.Cells, module TypedFlow.Layers.RNN.Attention) where import TypedFlow.Layers.RNN.Base import TypedFlow.Layers.RNN.Cells import TypedFlow.Layers.RNN.Attention ================================================ FILE: TypedFlow/Layers.hs ================================================ module TypedFlow.Layers (module TypedFlow.Layers.Core ,module TypedFlow.Layers.RNN ) where import TypedFlow.Layers.Core import TypedFlow.Layers.RNN ================================================ FILE: TypedFlow/Learn.hs ================================================ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE PatternSynonyms #-} {-| Module : TypedFlow.Learn Description : Loss functions and optimization strategies Copyright : (c) Jean-Philippe Bernardy, 2017 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental -} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ApplicativeDo #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UndecidableSuperClasses #-} {-# LANGUAGE UnicodeSyntax #-} module TypedFlow.Learn (-- losses: sparseCategorical, binary, timedCategorical, categoricalDistribution,sparseCategoricalDensePredictions, -- types Options(..), defaultOptions, Function(..),Model,ModelOutput, PreparedFunction(..), PreparedModel(..), -- other simpleModel, modelFunction, probeFunction, addRegularizer, prepare, -- utils placeholderName, ) where import Data.Proxy import TypedFlow.Types import TypedFlow.Types.Proofs (knownAppend, (?>), ) import TypedFlow.Broadcast (doBroadcast, mapPlaceHolders, ConsSh,doBroadcastSingle) import TypedFlow.Abstract (doExtractVars) import TypedFlow.TF import Prelude hiding (RealFrac(..)) import GHC.TypeLits -- | Triple of values that are always output in a model: prediction, loss and accuracy. -- @t@ is the type of the prediction. -- @s@ is the shape of the loss and accuracy type ModelOutput t predictionShape s = Placeholders '[ '("loss",s,Float32) -- loss associated with the prediction , '("accuracy",s,Float32) -- is the prediction correct? , '("y_",s++predictionShape,t) -- prediction (which can contain prediction-shaped info) ] pattern ModelOutput :: T (s++predictionShape) t -> T s Float32 -> T s Float32 -> ModelOutput t predictionShape s pattern ModelOutput y loss accur = PHT loss :* PHT accur :* PHT y :* Unit -- | A standard modelling function: (input value, gold value) ↦ (prediction, accuracy, loss). -- input is the shape of the input. -- output is the shape of the output (one element per individual loss and accuracy) -- p is the shape of each output element. -- g is the shape of each gold output --- often equal to p. type Model input tIn g p output tOut = T input tIn -> T (g++output) tOut -> ModelOutput tOut p output -- | First type argument is the number of classes. @categorical -- logits gold@ return (prediction, accuraccy, loss) sparseCategorical :: forall nCat. KnownNat nCat => Model '[nCat] Float32 '[] '[] '[] Int32 sparseCategorical logits y = let y_ = argmax0 logits modelCorrect = cast (equal y_ y) modelLoss = sparseSoftmaxCrossEntropyWithLogits y logits in ModelOutput y_ modelLoss modelCorrect -- | First type argument is the number of classes. @categorical -- logits gold@ return (prediction, accuracy, loss) sparseCategoricalDensePredictions :: forall nCat. KnownNat nCat => Tensor '[nCat] Float32 -> Tensor '[] Int32 -> ModelOutput Float32 '[nCat] '[] sparseCategoricalDensePredictions logits y = let y_ :: T '[nCat] Float32 y_ = softmax0 logits modelCorrect = cast (equal (argmax0 logits) y) modelLoss = sparseSoftmaxCrossEntropyWithLogits y logits in ModelOutput y_ modelLoss modelCorrect -- | First type argument is the number of classes. -- @categoricalDistribution logits gold@ return (prediction, -- accuraccy, loss) accuracy is reported as predicting the same class -- as the input 'winning' class. categoricalDistribution :: forall nCat. KnownNat nCat => Model '[nCat] Float32 '[nCat] '[nCat] '[] Float32 categoricalDistribution logits y = ModelOutput (softmax0 logits) (softmaxCrossEntropyWithLogits y logits) (cast (equal (argmax0 @'B32 logits) (argmax0 y))) -- | @timedCategorical targetWeights logits y@ -- -- targetWeights: a zero-one matrix of the same size as -- decoder_outputs. It is intended to mask padding positions outside -- of the target sequence lengths with values 0. -- -- Note that the accuracy is computed by multiplying the accuracies at -- individual time steps with the targetWeights. timedCategorical :: forall len nCat bits. KnownNat nCat => KnownNat len => KnownBits bits => Tensor '[len] (Flt bits) -> Tensor '[len,nCat] (Flt bits) -> Tensor '[len] Int32 -> ModelOutput (Flt bits) '[len,nCat] '[] timedCategorical targetWeights logits y = let y_ :: Tensor '[len] Int32 y_ = argmax1 logits modelY = softmax1 logits -- correct prediction for each position correctPrediction :: Tensor '[len] TFBool correctPrediction = equal y_ y -- total number of correct predictions correctPredictionWeighted :: Tensor '[] (Flt bits) correctPredictionWeighted = reduceSumAll (cast @(Flt bits) correctPrediction ⊙ targetWeights) weightSum = reduceSumAll targetWeights modelCorrect :: Tensor '[] Float32 modelCorrect = cast (correctPredictionWeighted / weightSum) crossEntropies = zipWithT sparseSoftmaxCrossEntropyWithLogits y logits modelLoss = cast @Float32 (reduceSumAll (crossEntropies ⊙ targetWeights) / weightSum) in ModelOutput modelY modelLoss modelCorrect -- | Model with @n@ binary outputs. binary :: KnownNat n => Model '[n] Float32 '[] '[] '[n] Int32 binary logits y = let y_ = cast @Int32 (round sigy_) sigy_ = sigmoid logits in ModelOutput (y_) (sigmoidCrossEntropyWithLogits (cast @Float32 y) logits) (cast (equal y_ y)) -- | Model compiler options data Options = Options {maxGradientNorm :: Maybe Prelude.Float -- ^ apply gradient clipping } -- | default model compiler options defaultOptions :: Options defaultOptions = Options {maxGradientNorm = Nothing} type family Concatenate xs where Concatenate (x ': xs) = x ++ Concatenate xs Concatenate '[] = '[] genPlaceholders :: All KnownPlaceholder shapesAndTypes => SList shapesAndTypes -> Placeholders shapesAndTypes genPlaceholders Unit = Unit genPlaceholders (ph :* names) = PHT (T (ExternalVar (Ref (placeholderName ph) typeSShape typeSTyp))) :* genPlaceholders names placeholderName :: forall (ph :: PH) p. KnownPlaceholder ph => p ph -> String placeholderName proxy = refName (placeHolderRef proxy) simpleModel :: forall p sx tx sy ty sy_ ty_. (KnownShape sy_, KnownShape p, KnownShape sx, KnownTyp ty_, KnownShape sy, KnownTyp tx, KnownTyp ty) => (Tensor sx tx -> Tensor sy ty -> ModelOutput ty_ p sy_) -> Function simpleModel f = knownAppend @sy_ @p ?> modelFunction "runModel" f' where f' :: Placeholders '[ '("x",sx,tx), '("y",sy,ty)] -> ModelOutput ty_ p sy_ f' (PHT x :* PHT y :* Unit) = f x y -- | Add a term to the loss. This function is intendend to add -- regularizers, ie. losses that do not depend on the predicted -- output, but rather on the structure of a parameter. addRegularizer :: Scalar Float32 -> Gen () addRegularizer r = GPState $ \GState{..} -> ((),GState{genRegularizers=r:genRegularizers,..}) knownBatchModel :: forall n ps. KnownNat n => NP (Sat KnownPlaceholder) ps -> NP (Sat KnownPlaceholder) (Ap (FMap (ConsSh n)) ps) knownBatchModel Unit = Unit knownBatchModel (Comp Dict :* xs) = Sat :* knownBatchModel @n xs -- | take the mean of loss/accur over the batch, etc. and add regulariser to loss consolidate :: forall s rest. KnownShape s => Scalar Float32 -> Placeholders ( '("loss",s ,Float32) ': '("accuracy",s ,Float32) ': rest) -> Placeholders ( '("loss",'[],Float32) ': '("accuracy",'[],Float32) ': rest) consolidate extraLoss (PHT loss :* PHT accur :* rest) = (PHT (reduceMeanAll loss + extraLoss) :* PHT (reduceMeanAll accur) :* rest) class (All KnownPlaceholder ps, KnownLen ps) => KnownPHS ps instance (All KnownPlaceholder ps, KnownLen ps) => KnownPHS ps data PreparedFunction = PreparedFunction {pfName :: String, pfBatched :: Bool, pfInputs, pfOutputs :: SomeSuch KnownPHS Placeholders} data PreparedModel = PreparedModel {pmBatchSize :: Integer, pmParams :: [VarInfo], pmFunctions :: [PreparedFunction] } -- | Prepare compilation of a model by: -- extracting and exposing parameters -- batching the model -- exposing placeholders -- consolidating loss and accuracy -- adding regularizers to the loss prepare :: forall bs. (KnownNat bs) => Gen [Function] -> PreparedModel prepare fGen = PreparedModel {pmBatchSize = natVal (Proxy @bs) ,pmParams = [VarInfo{varInitial=fmap doBroadcastSingle varInitial,..} | VarInfo{..} <- filter varTrainable vars] ,pmFunctions = flip map fs $ \case ModelFn nm st1 st2 f -> knownAll (knownBatchModel @bs st1) $ knownAll (knownBatchModel @bs st2) $ knownAll st1 $ knownAll st2 $ let placeHolders = genPlaceholders typeSList u = -777 -- magic unique identifier for the batch dimension in PreparedFunction nm True (SomeSuch placeHolders) (SomeSuch $ doBroadcast (consolidate {-@(bs ': s) @(BPH bs st2)-} regular (mapPlaceHolders @bs u True f placeHolders))) ProbeFn nm st1 st2 f -> knownAll st1 $ knownAll st2 $ let placeHolders = genPlaceholders typeSList in PreparedFunction nm False (SomeSuch placeHolders) (SomeSuch (doBroadcast (f placeHolders))) } where (fs,finalState,vars) = doExtractVars fGen regular = sum (genRegularizers finalState) data Function where ModelFn :: (KnownShape s, KnownLen st1, KnownLen st2) => String -> NP (Sat KnownPlaceholder) st1 -> NP (Sat KnownPlaceholder) st2 -> (Placeholders st1 -> Placeholders ('("loss",s,Float32) ': '("accuracy",s,Float32) ': st2)) -> Function ProbeFn :: (KnownLen st1, KnownLen st2, All KnownPlaceholder st1, All KnownPlaceholder st2) => String -> NP (Sat KnownPlaceholder) st1 -> NP (Sat KnownPlaceholder) st2 -> (Placeholders st1 -> Placeholders st2) -> Function modelFunction :: (KnownShape s, KnownLen st1, KnownLen st2, All KnownPlaceholder st1, All KnownPlaceholder st2) => String -> (Placeholders st1 -> Placeholders ('("loss",s,Float32) ': '("accuracy",s,Float32) ': st2)) -> Function modelFunction nm f = ModelFn nm (allKnown @KnownPlaceholder) (allKnown @KnownPlaceholder) f probeFunction :: (KnownLen st1, KnownLen st2, All KnownPlaceholder st1, All KnownPlaceholder st2) => String -> (Placeholders st1 -> Placeholders st2) -> Function probeFunction nm f = ProbeFn nm (allKnown @KnownPlaceholder) (allKnown @KnownPlaceholder) f ================================================ FILE: TypedFlow/Memo.hs ================================================ {-# LANGUAGE TypeInType #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE GADTs #-} module TypedFlow.Memo where import qualified Data.IntMap as I import qualified Data.Map.Strict as M import System.Mem.StableName import Data.IORef import System.IO.Unsafe import Unsafe.Coerce import Data.Kind (Type) type SNMap k v = I.IntMap [(StableName k,v)] snMapLookup :: StableName k -> SNMap k v -> Maybe v snMapLookup sn m = do x <- I.lookup (hashStableName sn) m lookup sn x snMapInsert :: StableName k -> v -> SNMap k v -> SNMap k v snMapInsert sn res = I.insertWith (++) (hashStableName sn) [(sn,res)] memo :: (a -> b) -> a -> b memo f = unsafePerformIO ( do { tref <- newIORef (I.empty) ; return (applyStable f tref) }) applyStable :: (a -> b) -> IORef (SNMap a b) -> a -> b applyStable f tbl arg = unsafePerformIO ( do { sn <- makeStableName arg ; lkp <- snMapLookup sn <$> readIORef tbl ; case lkp of Just result -> return result Nothing -> do { let res = f arg ; modifyIORef tbl (snMapInsert sn res) ; return res }}) memoOrd :: Ord a => (a -> b) -> a -> b memoOrd f = unsafePerformIO ( do { tref <- newIORef (M.empty) ; return (applyStableOrd f tref) }) applyStableOrd :: Ord a => (a -> b) -> IORef (M.Map a b) -> a -> b applyStableOrd f tbl arg = unsafePerformIO ( do { lkp <- M.lookup arg <$> readIORef tbl ; case lkp of Just result -> return result Nothing -> do { let res = f arg ; modifyIORef tbl (M.insert arg res) ; return res }}) data Some2 k1 k2 (f :: k1 -> k2 -> Type) where Some2 :: forall k1 k2 f a b. StableName (f a b) -> Some2 k1 k2 f instance Eq (Some2 k1 k2 f) where Some2 sn1 == Some2 sn2 = eqStableName sn1 sn2 type SSNMap2 k1 k2 (f :: k1 -> k2 -> Type) v = I.IntMap [(Some2 k1 k2 f,v)] makeSn2 :: f a b -> Some2 k1 k2 f makeSn2 = Some2 . unsafePerformIO . makeStableName snMapLookup2 :: Some2 k1 k2 f -> SSNMap2 k1 k2 f v -> Maybe v snMapLookup2 (Some2 sn) m = do x <- I.lookup (hashStableName sn) m lookup (Some2 sn) x snMapInsert2 :: Some2 k1 k2 f -> v -> SSNMap2 k1 k2 f v -> SSNMap2 k1 k2 f v snMapInsert2 (Some2 sn) res = I.insertWith (++) (hashStableName sn) [(Some2 sn,res)] data KV k1 k2 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) where KV :: forall k1 k2 f v a b. StableName (f a b) -> v a b -> KV k1 k2 f v type SNMap22 k1 k2 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) = I.IntMap [KV k1 k2 f v] snMap22Lookup :: StableName (f a b) -> SNMap22 k1 k2 f v -> Maybe (v a b) snMap22Lookup sn m = do x <- I.lookup (hashStableName sn) m lkKV sn x lkKV :: StableName (f a b) -> [KV k1 k2 f v] -> Maybe (v a b) lkKV _ [] = Nothing lkKV sn (KV sn' v:kvs) | eqStableName sn sn' = Just (unsafeCoerce v) -- sn == sn' -> a == a' and b == b' | otherwise = lkKV sn kvs snMap22Insert :: KV k1 k2 f v -> SNMap22 k1 k2 f v -> SNMap22 k1 k2 f v snMap22Insert (KV sn res) = I.insertWith (++) (hashStableName sn) [KV sn res] -- | The type of a memo table for functions of a. type Memo a = forall r. (a -> r) -> (a -> r) -- | Memoize a two argument function (just apply the table directly for -- single argument functions). memo2 :: Memo a -> Memo b -> (a -> b -> r) -> (a -> b -> r) memo2 a b = a . (b .) -- | Memoize a three argument function. memo3 :: Memo a -> Memo b -> Memo c -> (a -> b -> c -> r) -> (a -> b -> c -> r) memo3 a b c = a . (memo2 b c .) ================================================ FILE: TypedFlow/Memo2.hs ================================================ {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE GADTs #-} module TypedFlow.Memo2 where import Data.Kind (Type) import qualified Data.Map.Strict as M import System.Mem.StableName -- import Data.IORef -- import System.IO.Unsafe import Unsafe.Coerce import qualified Data.IntMap as I import Data.Type.Equality import Control.Monad.IO.Class import Data.IORef import TypedFlow.Types.Proofs (SingEq(..)) import Data.List (intercalate) data Map0 k (m :: Type -> Type) f v = forall . Map0 { m0Key :: f -> IO k, m0Empty :: m v, m0lk :: k -> m v -> Maybe v, m0upd :: k -> (Maybe v -> v) -> m v -> m v, m0fmap :: forall u w. (u -> w) -> m u -> m w, m0showKey :: k -> String, m0showTbl :: (v -> String) -> (m v -> String) } data Map1 (k :: k1 -> Type) (m :: (k1 -> Type) -> Type) (f :: k1 -> Type) (v :: k1 -> Type) = Map1 { m1Key :: forall x. f x -> IO (k x), m1Empty :: m v, m1lk :: forall x. k x -> m v -> Maybe (v x), m1upd :: forall x. k x -> (Maybe (v x) -> (v x)) -> m v -> m v, m1showKey :: forall x . k x -> String, m1showTbl :: (forall x . v x -> String) -> (m v -> String) } data Map2 (k :: k1 -> k2 -> Type) (m :: (k1 -> k2 -> Type) -> Type) (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) = Map2 { m2Key :: forall x y. f x y -> IO (k x y), m2Empty :: m v, m2lk :: forall x y. k x y -> m v -> Maybe (v x y), m2upd :: forall x y. k x y -> (Maybe (v x y) -> (v x y)) -> m v -> m v, -- m2fmap :: forall u w. (forall x y. u x y -> w x y) -> m u -> m w, m2showKey :: forall x y. k x y -> String, m2showTbl :: (forall x y. v x y -> String) -> (m v -> String) } data Map3 (k :: k1 -> k2 -> k3 -> Type) (m :: (k1 -> k2 -> k3 -> Type) -> Type) (f :: k1 -> k2 -> k3 -> Type) (v :: k1 -> k2 -> k3 -> Type) = Map3 { m3Key :: forall x y z. f x y z -> IO (k x y z), m3Empty :: m v, m3lk :: forall x y z. k x y z -> m v -> Maybe (v x y z), m3upd :: forall x y z. k x y z -> (Maybe (v x y z) -> (v x y z)) -> m v -> m v, m3showKey :: forall x y z. k x y z -> String, m3showTbl :: (forall x y z. v x y z -> String) -> (m v -> String) } newtype Id x = Id x ordMap :: forall k b. (Ord k, Show k) => Map0 k (M.Map k) k b ordMap = Map0 {..} where m0Key = return m0Empty = mempty m0lk k = M.lookup k m0upd k f m = M.alter (Just . f) k m m0fmap = fmap m0showKey = show m0showTbl sh m = intercalate ";" [(show k) <> "↦" <> (sh v) | (k,v) <- M.assocs m] data Single1 f g where None1 :: Single1 f g Single1 :: f a -> g a -> Single1 f g verifMap1 :: forall k v. SingEq k => Map1 k (Single1 k) k v verifMap1 = Map1 {..} where m1Key = return m1Empty = None1 m1lk :: k a -> Single1 k b -> Maybe (b a) m1lk k = \case None1 -> Nothing Single1 k' v -> case testEq k k' of Just Refl -> Just v Nothing -> error "verifMap1: mismatching keys! (1)" m1upd :: forall x. k x -> (Maybe (v x) -> (v x)) -> Single1 k v -> Single1 k v m1upd k f None1 = Single1 k (f Nothing) m1upd k f (Single1 k' v) = case testEq k k' of Just Refl -> Single1 k (f (Just v)) Nothing -> error "verifMap1: mismatching keys! (2)" m1showKey _ = "#" m1showTbl :: (forall x . v x -> String) -> (Single1 k v -> String) m1showTbl _ None1 = "·" m1showTbl h (Single1 _ v) = "!" <> (h v) testStable :: StableName a -> StableName b -> Maybe (a :~: b) testStable sn sn' | eqStableName sn sn' = Just (unsafeCoerce Refl) | otherwise = Nothing snMap2 :: forall f v. Map2 (SN2 f) (SNMap22 f) f v snMap2 = Map2 {..} where m2showTbl :: (forall x y. v x y -> String) -> (SNMap22 f v -> String) m2showTbl h (SNMap22 m) = intercalate "," [ m2showKey k <> "↦" <> h v | e <- I.elems m, KV k v <- e ] m2showKey (SN2 sn) = show (hashStableName sn) m2Key obj = SN2 <$> makeStableName obj m2Empty = mempty m2lk = snMap22Lookup m2upd :: SN2 f x y -> (Maybe (v x y) -> (v x y)) -> SNMap22 f v -> SNMap22 f v m2upd (SN2 sn) f (SNMap22 m) = SNMap22 $ I.alter (\case Nothing -> Just [KV (SN2 sn) (f Nothing)] Just p -> Just (updKV (SN2 sn) f p)) (hashStableName sn) m updKV :: SN2 f' x y -> (Maybe (v' x y) -> (v' x y)) -> [KV k1 k2 (SN2 f') v'] -> [KV k1 k2 (SN2 f') v'] updKV (SN2 sn) f [] = [KV (SN2 sn) (f Nothing)] updKV (SN2 sn) f (v@(KV (SN2 sn') x):xs) = case testStable sn sn' of Just Refl -> KV (SN2 sn') (f (Just x)):xs Nothing -> v : updKV (SN2 sn) f xs -- m2fmap :: forall u w. (forall x y. u x y -> w x y) -> SNMap22 f u -> SNMap22 f w -- m2fmap h (SNMap22 t) = SNMap22 (fmap (fmap (\(KV k v) -> KV k (h v))) t) snMap22Lookup :: forall a b f' v'. SN2 f' a b -> SNMap22 f' v' -> Maybe (v' a b) snMap22Lookup (SN2 sn) (SNMap22 m) = do x <- I.lookup (hashStableName sn) m lkKV sn x lkKV :: forall k1 k2 f' v' a b . StableName (f' a b) -> [KV k1 k2 (SN2 f') v'] -> Maybe (v' a b) lkKV _ [] = Nothing lkKV sn (KV (SN2 sn') v:kvs) = case testStable sn sn' of Just Refl -> Just (unsafeCoerce v) -- sn == sn' -> a == a' and b == b' Nothing -> lkKV sn kvs data KV k1 k2 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) where KV :: forall k1 k2 f v a b. f a b -> v a b -> KV k1 k2 f v newtype SNMap22 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) = SNMap22 (I.IntMap [KV k1 k2 (SN2 f) v]) deriving (Monoid, Semigroup) newtype SN2 (f :: k1 -> k2 -> Type) a b = SN2 (StableName (f a b)) data (:.:) (m1 :: k2 -> Type) (m2 :: k1 -> k2) (h :: k1) = Comp (m1 (m2 h)) data Sig02 f g x y where Ex02 :: f -> g x y -> Sig02 f g x y data Sig03 f g x y z where Ex03 :: f -> g x y z -> Sig03 f g x y z data Sig12 f g x y z where Ex12 :: f x -> g y z -> Sig12 f g x y z data Sig22 f g x y where Ex22 :: f x y -> g x y -> Sig22 f g x y data P33 f g x y z where T33 :: f x y z -> g x y z -> P33 f g x y z containing00 :: (forall v. Map0 k1 m1 f v) -> Map0 k2 m2 g h -> Map0 (k1,k2) (m1 :.: m2) (f,g) h containing00 f g = Map0 { m0Key = (\(a,b) -> (,) <$> m0Key f a <*> m0Key g b), m0Empty = Comp (m0Empty f), m0lk = \(k1,k2) (Comp t) -> do t' <- m0lk f k1 t; m0lk g k2 t', m0upd = \(k1,k2) h (Comp t) -> Comp $ m0upd f k1 (m0upd g k2 h . \case Just tb -> tb; Nothing -> (m0Empty g)) t, m0fmap = \h (Comp t) -> Comp $ m0fmap f (m0fmap g h) t, m0showKey = \(k1,k0) -> m0showKey f k1 <> "," <> m0showKey g k0, m0showTbl = \h (Comp t) -> m0showTbl f (m0showTbl g h) t } containing02 :: (forall v. Map0 k1 m1 f v) -> Map2 k2 m2 g h -> Map2 (Sig02 k1 k2) (m1 :.: m2) (Sig02 f g) h containing02 f g = Map2 { m2Key = (\(Ex02 a b) -> Ex02 <$> m0Key f a <*> m2Key g b), m2Empty = Comp (m0Empty f), m2lk = \(Ex02 k1 k2) (Comp t) -> do t' <- m0lk f k1 t; m2lk g k2 t', m2upd = \(Ex02 k1 k2) h (Comp t) -> Comp $ m0upd f k1 (m2upd g k2 h . \case Just tb -> tb; Nothing -> (m2Empty g)) t, -- m2fmap = \h (Comp t) -> Comp $ m0fmap f (m2fmap g h) t, m2showKey = \(Ex02 k1 k2) -> m0showKey f k1 <> "," <> m2showKey g k2, m2showTbl = \h (Comp t) -> m0showTbl f (m2showTbl g h) t } containing03 :: (forall v. Map0 k1 m1 f v) -> Map3 k2 m2 g h -> Map3 (Sig03 k1 k2) (m1 :.: m2) (Sig03 f g) h containing03 f g = Map3 { m3Key = (\(Ex03 a b) -> Ex03 <$> m0Key f a <*> m3Key g b), m3Empty = Comp (m0Empty f), m3lk = \(Ex03 k1 k3) (Comp t) -> do t' <- m0lk f k1 t; m3lk g k3 t', m3upd = \(Ex03 k1 k3) h (Comp t) -> Comp $ m0upd f k1 (m3upd g k3 h . \case Just tb -> tb; Nothing -> (m3Empty g)) t, m3showKey = \(Ex03 k1 k2) -> m0showKey f k1 <> "," <> m3showKey g k2 , m3showTbl = \h (Comp t) -> m0showTbl f (m3showTbl g h) t } newtype Lam' (m2 :: (k2 -> k3 -> Type) -> Type) (h :: k1 -> k2 -> k3 -> Type) (a :: k1) = Lam' {fromLam' :: (m2 (h a))} data M12 (m1 :: (k1 -> Type) -> Type) (m2 :: (k2 -> k3 -> Type) -> Type) (h :: k1 -> k2 -> k3 -> Type) = M12 (m1 (Lam' m2 h)) containing12 :: (forall v. Map1 k1 m1 f v) -> (forall k4. Map2 k2 m2 g (h k4)) -> Map3 (Sig12 k1 k2) (M12 m1 m2) (Sig12 f g) h containing12 f g = Map3 { m3Key = (\(Ex12 a b) -> Ex12 <$> m1Key f a <*> m2Key g b), m3Empty = M12 (m1Empty f), m3lk = \(Ex12 k1 k2) (M12 t) -> do Lam' t' <- m1lk f k1 t; m2lk g k2 t', m3upd = \(Ex12 k1 k2) h (M12 t) -> M12 $ m1upd f k1 (Lam' . m2upd g k2 h . (\case Just tb -> tb; Nothing -> m2Empty g) . fmap fromLam') t, m3showKey = \(Ex12 k1 k2) -> m1showKey f k1 <> ">" <> m2showKey g k2, m3showTbl = \h (M12 t) -> m1showTbl f (m2showTbl g h . fromLam') t } data F2m m g h = F2m (forall x y. g x y -> m (h x y)) data F2m' m g f h = F2m' (forall x y. g x y -> f x y -> m (h x y)) data F3m m g h = F3m (forall x y z. g x y z -> m (h x y z)) data F3m' m g f h = F3m' (forall x y z. g x y z -> f x y z -> m (h x y z)) memo2 :: forall g h k m n. MonadIO n => Map2 k m g h -> ((forall x y. g x y -> n (h x y)) -> forall x y. g x y -> n (h x y)) -> n (F2m n g h) memo2 Map2{..} f = do tblRef <- liftIO $ newIORef m2Empty let finished :: forall x y. g x y -> n (h x y) finished arg = do tbl <- liftIO $ readIORef tblRef key <- liftIO $ m2Key arg case m2lk key tbl of Just result -> do -- liftIO $ putStrLn "memo2: hit" return result Nothing -> do -- liftIO $ putStrLn "memo2: miss" res <- f finished arg liftIO $ modifyIORef tblRef (m2upd key $ \_ -> res) return res return (F2m finished) memo2' :: forall g f h k m n. MonadIO n => Map2 k m g h -> ((forall x y. g x y -> f x y -> n (h x y)) -> forall x y. g x y -> f x y -> n (h x y)) -> n (F2m' n g f h) memo2' Map2{..} f = do tblRef <- liftIO $ newIORef m2Empty let finished :: forall x y. g x y -> f x y -> n (h x y) finished arg extra = do tbl <- liftIO $ readIORef tblRef key <- liftIO $ m2Key arg case m2lk key tbl of Just result -> do -- liftIO $ putStrLn "memo2': hit" return result Nothing -> do -- liftIO $ putStrLn ("memo2: miss " <> m2showKey key) -- <> " from " <> m3showTbl (const ".") tbl res <- f finished arg extra liftIO $ modifyIORef tblRef (m2upd key $ \_ -> res) return res return (F2m' finished) memo3' :: forall g f h k m n. MonadIO n => Map3 k m g h -> ((forall x y z. g x y z -> f x y z -> n (h x y z)) -> forall x y z. g x y z -> f x y z -> n (h x y z)) -> n (F3m' n g f h) memo3' Map3{..} f = do tblRef <- liftIO $ newIORef m3Empty let finished :: forall x y z. g x y z -> f x y z -> n (h x y z) finished arg extra = do tbl <- liftIO $ readIORef tblRef key <- liftIO $ m3Key arg case m3lk key tbl of Just result -> do -- liftIO $ putStrLn "memo3: hit" return result Nothing -> do -- liftIO $ putStrLn ("memo3: miss " <> m3showKey key) -- <> " from " <> m3showTbl (const ".") tbl res <- f finished arg extra liftIO $ modifyIORef tblRef (m3upd key $ \_ -> res) return res return (F3m' finished) ================================================ FILE: TypedFlow/Models/Topic.hs ================================================ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UnicodeSyntax #-} {-| Module : TypedFlow.Models.Topic Description : Topic models Copyright : (c) Jean-Philippe Bernardy, 2017 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental -} module TypedFlow.Models.Topic where import Prelude hiding (RealFrac(..)) import TypedFlow.TF import TypedFlow.Layers import TypedFlow.Types import TypedFlow.Types.Proofs ((?>), knownSum') import TypedFlow.Learn import GHC.TypeLits import Data.Monoid ((<>)) import Data.Proxy -- | A convolutional document summary function. Described in -- 'Topically Driven Neural Language Model' by Lau, Baldwin and Cohn. tdlmDocsummary :: forall (vocSize :: Nat) -- number of words (e :: Nat) -- size of the embedding (a :: Nat) -- number of features of the document vector summary (n :: Nat) -- length of the document (filterSize :: Nat) -- size of the convolution filter (t :: NBits) -- size of floats . KnownNat vocSize => KnownNat filterSize => KnownNat e => KnownNat a => KnownNat n => KnownBits t => (EmbeddingP vocSize e (Flt t)) -> (ConvP (Flt t) a e '[filterSize]) -> DropProb -> Gen (T '[n] Int32 -> T '[a] (Flt t)) tdlmDocsummary embs filters dropProb = do drpEmb <- mkDropout dropProb return $ \document -> let embeddedDoc :: Tensor [n,e] (Flt t) embeddedDoc = mapT (drpEmb . embedding @e @vocSize embs) document in reduceMax axis0 (conv' @'[n] filters embeddedDoc) tdlmDocsummary' :: forall (vocSize :: Nat) -- number of words (e :: Nat) -- size of the embedding (n :: Nat) -- length of the document -- (a :: Nat) -- number of features of the document vector summary -- (filterSize :: Nat) -- size of the convolution filter spec (t :: NBits) -- size of floats proxy . KnownNat vocSize => KnownNat (Ap Frst' spec) => KnownNat e => KnownNat (Ap Scnd' spec) => KnownNat n => KnownBits t => proxy spec -> (EmbeddingP vocSize e (Flt t)) -> (ConvP (Flt t) (Ap Scnd' spec) e '[(Ap Frst' spec)]) -> DropProb -> Gen (T '[n] Int32 -> T '[Ap Scnd' spec] (Flt t)) tdlmDocsummary' _proxy = tdlmDocsummary scnds :: SList xs -> SList (Ap (FMap Scnd') xs) scnds Unit = Unit scnds (_ :* xs) = Proxy :* scnds xs -- hmap _ Unit = Unit -- hmap f (x :* xs) = f x :* hmap f xs mkTdlmDocsummary :: forall (vocSize :: Nat) -- number of words (e :: Nat) -- size of the embedding (spec :: [(Nat,Nat)]) -- (size of the convolution filter,number of features) (n :: Nat) -- length of the document (t :: NBits) -- size of floats . KnownNat vocSize => KnownNat e => KnownNat n => KnownBits t => All KnownNat (Ap (FMap Scnd') spec) => All KnownNat (Ap (FMap Frst') spec) => SList spec -> DropProb -> Gen (T '[n] Int32 -> T '[Sum (Ap (FMap Scnd') spec)] (Flt t)) mkTdlmDocsummary xs0 dropProb = case xs0 of Unit -> return (\_ -> zeros) (proxy :* xs) -> knownSum' (scnds xs) ?> do embs <- parameterDefault ("embs_topic_" ++ show (sListLength xs)) filters <- parameterDefault ("filters_topic_" ++ show (sListLength xs)) f <- tdlmDocsummary' @vocSize @e proxy embs filters dropProb fs <- mkTdlmDocsummary @vocSize @e xs dropProb return $ \input -> concat0 (f input) (fs input) -- | Parameter for topics. This is effectively map from document -- features (a) to topic representations (vectors of size b) via k -- topic distributions. data TopicP t a k b = TopicP {topicDistributions :: (T '[a,k] (Flt t)) -- ^ a linear map from documents features (a) to topic distributions (k) ,topicRepresentations :: (T '[k,b] (Flt t)) -- ^ a linear map from topic distributions (k) to topic representations (b) } instance (KnownNat a, KnownNat k, KnownNat b, KnownBits t) => KnownTensors (TopicP t a k b) where travTensor f s (TopicP x y) = TopicP <$> travTensor f (s<>"_A") x <*> travTensor f (s<>"_B") y instance (KnownNat a, KnownNat k, KnownNat b, KnownBits t) => ParamWithDefault (TopicP t a k b) where defaultInitializer = TopicP <$> glorotUniform <*> glorotUniform -- | A topic modeler. Described 'Topically Driven Neural Language -- Model' by Lau, Baldwin and Cohn. Returns a function converting raw -- representations (eg. document summaries) to topic representations. -- This representation can be used as input to a dense layer to -- predict a word, or as input to an LSTM (initial state) to predict -- sentences. mkTdlmTopic :: forall (kk :: Nat) -- number of topics (a :: Nat) -- document vector summary size (b :: Nat) -- topic representation size (t :: NBits) -- size of floats . KnownNat kk => KnownNat a => KnownNat b => KnownBits t => Float -> TopicP t a kk b -> Gen (T '[a] (Flt t) -> (Tensor '[b] (Flt t), Tensor '[kk] (Flt t))) mkTdlmTopic separationConstant (TopicP topicInput topicOutput) = do drpS <- mkDropout (DropProb 0.1) let topicNormalized :: T '[kk,b] (Flt t) topicNormalized = mapT normalize topicOutput -- matrix of correlation between the topics topicCorrelation :: T '[kk,kk] (Flt t) topicCorrelation = matmul topicNormalized (transpose01 topicNormalized) -- max correlation between two distinct topics topicOverlap = reduceMaxAll (square (topicCorrelation ⊝ eye)) addRegularizer (constant separationConstant ⊙ cast topicOverlap) -- regularizer which ensures that topics are disjoint return (\d -> let p :: T '[kk] (Flt t) p = softmax0 (topicInput ∙ d) -- attention distribution (among the topics) in (drpS (topicOutput ∙ p), p)) -- | Gating unit which can be used to mix a RNN hidden state with an -- external information source (eg. topic representation). Described -- 'Topically Driven Neural Language Model' by Lau, Baldwin and Cohn; -- formula (3) tdlmGatingUnit :: KnownNat n => KnownFloat t => KnownNat m => (GRUP t m n) -> T '[n] t -> T '[m] t -> (T '[m] t) tdlmGatingUnit (GRUP wz wr w) s h = let x = concat0 h s z = sigmoid (wz ∙ x) r = sigmoid (wr ∙ x) hTilda = tanh (w ∙ (concat0 (r ⊙ h) s)) in ((ones ⊝ z) ⊙ h + z ⊙ hTilda) ================================================ FILE: TypedFlow/Models/Transformer.hs ================================================ {-# LANGUAGE PartialTypeSignatures #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UnicodeSyntax #-} {-# LANGUAGE NoStarIsType #-} {-| Module : TypedFlow.Models.Transformer Description : Topic models Copyright : (c) Jean-Philippe Bernardy, 2020 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental -} module TypedFlow.Models.Transformer where import Prelude hiding (RealFrac(..)) import TypedFlow.TF import TypedFlow.Abstract import TypedFlow.Layers import TypedFlow.Types import TypedFlow.Types.Proofs ((?>), knownSum') import GHC.TypeLits -- Convention for type variables: -- h = number of heads -- e = embedding size -- n = sequence length average :: forall e. KnownNat e => T '[e] Float32 -> Scalar Float32 average = reduceMeanAll -- | Normalise a vector. But add a small epsilon to avoid division by zero normalizer :: forall e. KnownNat e => T '[e] Float32 -> T '[e] Float32 normalizer x = mapT (⊘ (sigma + epsilon)) xmu -- so the norm of result is almost 1 where mu = average x xmu = mapT (⊝ mu) x -- so the average of xmu is 0 sigma = sqrt (average (square xmu)) -- the norm of xmu. epsilon = 0.001 -- ? -- Informally: -- mapT f x = vector y such that y_i = f (x_i) -- (the first axis) dimAsFloat :: forall e. KnownNat e => Float dimAsFloat = fromIntegral (knownNatVal (natSat @e)) -- | dot product attention on one key (k) dotAttention1 :: forall e n. KnownNat e => KnownNat n => T '[e,n] Float32 -> T '[n,e] Float32 -> T '[e] Float32 -> T '[e] Float32 dotAttention1 q v k = v ∙ softmax0 (mapT (⊘ normFactor) (q ∙ k)) where normFactor = constant (sqrt (dimAsFloat @e)) -- | dot product attention for every position dotAttention :: forall n e. KnownNat n => KnownNat e => T '[n,e] Float32 -> T '[n,e] Float32 -> T '[n,e] Float32 -> T '[n,e] Float32 dotAttention v k q = mapT (dotAttention1 (transpose01 q) v) k -- | h copies of a dense layer (the same for every copy). multiheadLinearEncoder :: forall h e. KnownNat e => KnownNat h => String -> Gen (T '[e] Float32 -> T '[h,e] Float32) multiheadLinearEncoder name = do wv <- parameterDefault ("w_" ++ name) return $ \x -> reshape (wv # x) multiheadSelfAttentionModule :: forall h n e. KnownNat n => KnownNat h => KnownNat e => String -> Gen (T '[n,e] Float32 -> T '[n,e] Float32) multiheadSelfAttentionModule nm = do ev <- multiheadLinearEncoder @h ("v" ++ nm) eq <- multiheadLinearEncoder @h ("q" ++ nm) ek <- multiheadLinearEncoder @h ("k" ++ nm) w1 <- parameterDefault ("w1" ++ nm) -- w2 <- parameterDefault ("w2" ++ nm) return $ \x -> let v = transpose01 (mapT ev x) q = transpose01 (mapT eq x) k = transpose01 (mapT ek x) r :: T '[n,h,e] Float32 r = transpose01 (zipWith3T dotAttention q k v) r' = mapT (dense @e w1 . reshape @'[h * e]) r in mapT ({-dense w2 . -}normalizer) (r' + x) -- x + mapT normalizer r' multiheadSelfAttentionModuleDecoder :: forall h n e. KnownNat n => KnownNat h => KnownNat e => String -> Gen (T '[n,e] Float32 -> T '[n,e] Float32 -> T '[n,e] Float32) multiheadSelfAttentionModuleDecoder nm = do ev <- multiheadLinearEncoder @h ("v" ++ nm) eq <- multiheadLinearEncoder @h ("q" ++ nm) ek <- multiheadLinearEncoder @h ("k" ++ nm) w1 <- parameterDefault ("w1" ++ nm) -- w2 <- parameterDefault ("w2" ++ nm) return $ \x -- comes from decoder y -- comes from encoder -> let k = transpose01 (mapT ek y) v = transpose01 (mapT ev x) q = transpose01 (mapT eq y) r :: T '[n,h,e] Float32 r = transpose01 (zipWith3T dotAttention q k v) r' = mapT (dense @e w1 . reshape @'[h * e]) r in mapT ({-dense w2 . -}normalizer) (r' + x) -- x + mapT normalizer r' feedForwardModule :: forall e. KnownNat e => String -> Gen (T '[e] Float32 -> T '[e] Float32) feedForwardModule nm = do w1 :: DenseP Float32 e e <- parameterDefault (nm ++ "w1") w2 <- parameterDefault (nm ++ "w2") return $ \x -> normalizer (x + (w2 # relu (w1 # x))) encoderModule :: forall h n e. KnownNat n => KnownNat h => KnownNat e => DropProb -> String -> T '[n,e] Float32 -> Gen (T '[n,e] Float32 -> T '[n,e] Float32) encoderModule dropProb nm positionalTensor = do drp <- mkDropout dropProb selfAtt <- multiheadSelfAttentionModule @h (nm ++ "mh") ff <- feedForwardModule (nm ++ "ff") return (mapT ff . selfAtt . (+ positionalTensor) . drp) positionalModuleSinCos :: forall n e. KnownNat e => KnownNat n => T '[n,e] Float32 positionalModuleSinCos = sin (transpose01 (broadcastT pos) * (broadcastT omega) + broadcastT phase) where pos = (cast (range @n @'B32)) :: T '[n] Float32 phase = cast ((range @e @'B32) `floorMod` constant 2) * (constant pi/2) :: T '[e] Float32 omega = constant (log 10000) * exp (constant (-2.0 / dimAsFloat @e) * cast (range @e @'B32)) -- Note I'm not dividing the frequence by 2 because integer -- division isn't implemented. Should not have any consequence. positionalModuleLearned :: KnownNat e => KnownNat n => Gen (T '[n,e] Float32) positionalModuleLearned = do e <- parameterDefault "positional" return $ let EmbeddingP x = e in x encoderStack :: forall h n e. KnownNat h => KnownNat n => KnownNat e => DropProb -> Int -> Gen (T '[n,e] Float32 -> T '[n,e] Float32) encoderStack dropProb n = do p <- positionalModuleLearned encoders <- mapM (\i -> encoderModule @h dropProb ("enc" ++ show i) p) [1..n] return (foldr (.) id encoders) -- n-ary function composition ================================================ FILE: TypedFlow/Python.hs ================================================ {-# LANGUAGE ViewPatterns #-} {-| Module : TypedFlow.Python Description : Python-generation Functions Copyright : (c) Jean-Philippe Bernardy, 2017 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental -} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UndecidableSuperClasses #-} {-# LANGUAGE UnicodeSyntax #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module TypedFlow.Python (compile, compileGen, generateFile) where import Data.Char (toLower) import Data.Proxy import Data.List (genericReplicate, ) import GHC.TypeLits import Control.Monad.State import TypedFlow.Types import TypedFlow.Broadcast (permToFun,unopInputShape) import TypedFlow.Types.Proofs import TypedFlow.Memo import Prettyprinter as PP import Prettyprinter.Render.String as PP import qualified Data.Map as M import TypedFlow.Learn import qualified Data.Sequence as S import Data.Sequence (Seq, (|>), ) import Data.Foldable first :: (t -> a) -> (t, b) -> (a, b) first f (x,y) = (f x,y) paramShape' :: VarInfo -> [Integer] paramShape' (VarInfo {varRef = Ref _ s _}) = shapeToList' s paramDType :: VarInfo -> Typ paramDType (VarInfo {varRef = Ref _ _ t}) = sTypTyp t paramName :: VarInfo -> String paramName (VarInfo {varRef = Ref {..}}) = refName generateFile :: String -> Python [VarInfo] -> IO () generateFile fname g = do putStrLn ("Parameters (total " ++ show (sum [product (paramShape' p) | p <- params]) ++ "):") forM_ params printParam writeFile fname output where (output,params) = generate g printParam p = putStrLn (paramName p ++ ": " ++ "T " ++ renderSimple (showShape' (paramShape' p)) ++ " " ++ showT (paramDType p)) named :: String -> DOC -> DOC named fname x = text (fname <> "=") <> x text :: String -> DOC text = pretty genFun :: forall b. String -> [DOC] -> Python b -> Python b genFun name args body = do gen (text "def " <> text name <> align (tuple args) <> text ":") withDOC (\b -> " " <> align b) body showTyp :: forall t. KnownTyp t => DOC showTyp = text (showT (typVal @t)) showSTyp :: forall t. STyp t -> DOC showSTyp t = knownTyp t $ showTyp @t showT :: Typ -> [Char] showT (Typ Bool _) = "tf.bool" showT (Typ Cmplx B32) = "tf.complex64" showT (Typ Cmplx B64) = "tf.complex128" showT (Typ k l) = "tf." ++ map toLower (show k) ++ drop 1 (show l) showShape' :: [Integer] -> DOC showShape' s = list (map (showDim' "None") s) showShape :: ∀ (s :: Shape). All KnownNat s => SList s -> DOC showShape s = showShape' (shapeToList'' s) showSShape :: ∀ (s :: Shape). SShape s -> DOC showSShape s = showShape' (shapeToList' s) showShapeType :: ∀ (s :: Shape). KnownShape s => DOC showShapeType = showSShape (typeSShape @s) -- | Show a shape, but "None" is replaced by "-1" showShapeMinus :: forall (s::Shape) proxy. All KnownNat s => SList' proxy s -> DOC showShapeMinus s = list (map (showDim' "-1") (shapeToList'' s)) showShapeLen :: ∀ (s::Shape). KnownLen s => DOC showShapeLen = (text . show) (listTypeLen @ s) showDim' :: String -> Integer -> DOC showDim' none n = text (if n == 514229 then none else show n) showDimM :: forall n. KnownNat n => DOC showDimM = showDim' "-1" (natVal (Proxy @ n)) showDim :: forall n. KnownNat n => DOC showDim = showDim' "None" (natVal (Proxy @ n)) showDimS :: forall n. Sat KnownNat n -> DOC showDimS Sat = showDim @n gen :: DOC -> Python () gen s = modify $ \PyState{..} -> PyState {genText=genText |> s,..} setGen :: Seq DOC -> Python () setGen d = modify $ \PyState{..} -> PyState {genText=d,..} (<--) :: Ref Int s t -> UntypedExpression -> Python () x <-- y = gen (pyVarRepr x <> text "=" <> y) renderSimple :: Doc ann -> String renderSimple = renderString . layoutPretty (LayoutOptions Unbounded) -- | save an intermediate result to a variable and save it to -- genAssignTable for future re-use. cache :: forall s t. KnownTyp t => KnownShape s => DOC -> Python DOC cache x = do let x' = renderSimple x mcache <- M.lookup x' <$> gets genAssignTable case mcache of Just y -> do -- comment ("cache hit: " <> text x') return y Nothing -> do -- comment ("cache miss") v <- newPyVar @s @t comment ("shape: " <> (showShapeType @s)) v <-- x modify $ (\g -> g {genAssignTable = M.insert x' (pyVarRepr v) (genAssignTable g)}) return (pyVarRepr v) newPyVar' :: forall s t. SShape s -> STyp t -> Python (Ref Int s t) newPyVar' s t = knownSShape s ?> (knownTyp t $ newPyVar @s @t) newId :: Python Integer newId = do n <- gets genId modify $ \PyState{..} -> PyState {genId=genId+1,..} return n newPyVar :: forall s t. KnownShape s => KnownTyp t => Python (Ref Int s t) newPyVar = do n <- newId return $ Ref (fromIntegral n) typeSShape typeSTyp pyVarInfoRepr :: VarInfo -> DOC pyVarInfoRepr i = text (varName i) pyVarRepr :: Ref Int s t -> DOC pyVarRepr (Ref n _ _) = text ("var" <> show n) tuple :: [DOC] -> DOC tuple = parens . align . sep . punctuate comma dict :: [(String,DOC)] -> DOC dict xs = braces $ align $ sep $ punctuate comma [text (show k) <> ":" <> v | (k,v) <- xs] funcall :: String -> [DOC] -> DOC funcall = funcall' . text funcall' :: DOC -> [DOC] -> DOC funcall' f args = f <> tuple args comment :: DOC -> Python () comment c = gen ("#" <> c) func :: String -> [DOC] -> [(String,DOC)] -> DOC func fname positional namedArgs = funcall fname (positional ++ map (uncurry named) namedArgs ) withDOC :: forall a. (DOC -> DOC) -> Python a -> Python a withDOC f g = do before <- gets genText setGen mempty x <- g after <- gets genText setGen (before |> f (vcat $ toList after)) return x generate :: Python [VarInfo] -> (String,[VarInfo]) generate s = (renderString (layoutPretty (LayoutOptions (AvailablePerLine 92 1)) (vcat $ toList genText)), genPyVars) where (genPyVars,PyState{..}) = runState s initPyState initPyState = PyState {genPureTable = mempty ,genAssignTable = mempty ,genText = mempty ,genId = 10000} generatePure :: forall s t. KnownTyp t => KnownShape s => T s t -> Python DOC generatePure x = do let sn = makeSn2 x mv <- snMapLookup2 sn <$> gets genPureTable case mv of Just v -> do -- comment ("gp hit:" <> v) return v Nothing -> do -- comment ("gp miss") e <- generatePure' (\s x' -> knownSShape s ?> generatePure x') typeSShape x v <- cache @s @t e modify (\g -> g {genPureTable = (snMapInsert2 sn v) (genPureTable g)}) return v genDistr :: forall s s0 t. KnownTyp t => Distribution s t -> SShape s0 -> SShape s -> DOC genDistr d sh s1 = case d of TruncatedNormalD stddev -> funcall "tf.random.truncated_normal" [showSShape (sh .+. s1), named "stddev" (float stddev), named "dtype" (showTyp @t)] UniformD low high -> funcall "tf.random.uniform" [showSShape (sh .+. s1) ,named "minval" (float low) ,named "maxval" (float high) ,named "dtype" (showTyp @t)] OrthogonalD -> funcall' (funcall "tf.keras.initializers.orthogonal" []) [named "dtype" (showTyp @t), named "shape" (showSShape (sh .+. s1))] generatePure' :: forall s t. KnownTyp t => (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> Python DOC) -> SShape s -> T s t -> Python DOC generatePure' rec sR = knownSShape sR ?> \case Unbroadcast{} -> error "broadcasting operation did not complete (Unbroadcast)!" BroadcastT _ _ _ sh x -> --- error "broadcasting operation did not complete (BroadcastT)!" do -- debug help rx <- rec sh x return (funcall "ERROR:BroadcastT" [rx]) MapT {} -> error "broadcasting operation did not complete (mapT)!" ZipT {} -> error "broadcasting operation did not complete (ZipT)!" Zip3T {} -> error "broadcasting operation did not complete (Zip3T)!" If c x y -> do rc <- rec typeSShape c rx <- rec typeSShape x ry <- rec typeSShape y return (func "tf.cond" [rc] [("true_fn", lambda0 rx) ,("false_fn", lambda0 ry)]) where lambda0 z = text "lambda: " <> z -- if broadcast_to is broken: https://github.com/tensorflow/tensorflow/issues/21901 -- DirectBroadcast s0 s1 s2 s3 x -> do -- recx <- rec (s0 .+. s2) x -- let expanded = func "tf.reshape" [recx,list (map (showDim' "-1") -- (concat [shapeToList' s0, genericReplicate (sListLength s1) 1 -- ,shapeToList' s2, genericReplicate (sListLength s3) 1 ]))] [] -- return (funcall "tf.add" [expanded, func "tf.zeros" [showSShape sR] [("dtype", showTyp @t)]]) DirectBroadcast s0 s1 s2 s3 x -> do recx <- rec (s0 .+. s2) x let expanded = func "tf.reshape" [recx,list (map (showDim' "-1") (concat [shapeToList' s0, genericReplicate (sListLength s1) 1 ,shapeToList' s2, genericReplicate (sListLength s3) 1 ]))] [] return (funcall "tf.broadcast_to" [expanded, showSShape sR]) Noise noiseId s0 s1 x -> do return $ (genDistr x s0 s1) <+> (text "# " <> integer noiseId) T op -> return $ case op of ExternalVar (Ref v _ _) -> text v Variable v -> pyVarRepr v (Constant c) -> funcall "tf.constant" [prettyT @t c, named "shape" (showSShape sR), named "dtype" (showTyp @t)] (Range n@Sat) -> (func "tf.range" [] [("start",integer 0), ("limit",integer (natVal n)), ("dtype",showTyp @t)]) Where c x y -> do rc <- rec typeSShape c rx <- rec typeSShape x ry <- rec typeSShape y return (funcall "tf.where" [rc, rx, ry]) UnOp operation s0 x -> do recx <- rec (s0 .+. unopInputShape operation) x return $ case operation of Diag _ -> funcall "tf.matrix_diag" [recx] Cast -> funcall "tf.cast" [recx,showTyp @t] StopGradient -> funcall "tf.stop_gradient" [recx] ExpM _ -> funcall "tf.linalg.expm" [recx] ZeroTriangle _ side k -> funcall ("tf.experemental.numpy.tri" ++ case side of Upper -> "u"; Lower -> "l") [recx, integer k] Conjugate -> funcall "tf.math.conj" [recx] RealPart -> funcall "tf.math.real" [recx] Axis1Op _ (SliceOp _ _ lo hi) -> recx <> list (replicate (fromIntegral (sListLength s0)) (text ":") ++ [integer lo <> text ":" <> integer hi]) Axis1Op _ (AccessOp _ idx) -> recx <> list (replicate (fromIntegral (sListLength s0)) (text ":") ++ [integer idx]) Axis1Op _ op' -> let (op,args) = case op' of SliceOp {} -> error "Python: panic: sliceop is special" AccessOp {} -> error "Python: panic: accessop is special" ReverseT _ -> ("tf.reverse",[]) OneHot depth -> ("tf.one_hot",[("dtype",showTyp @t), ("depth", showDimS depth)]) ArgMax{} -> ("tf.argmax",[("output_type",showTyp @t)]) ReduceOp _ r -> ("tf.reduce_" ++ rop, []) where rop = case r of Max -> "max" Min -> "min" Sum -> "sum" Mean -> "mean" axisName = if op == "tf.nn.softmax" then "dim" else "axis" -- use dim before TF 1.5 useAxisList = case op' of ReverseT _ -> True; _ -> False in func op [recx] ((axisName,(if useAxisList then (list . (:[])) else id) (integer (sListLength s0))):args) Float1Op op' -> funcall op (recx:args) where (op,args) = case op' of HardSigmoid -> ("tf.keras.backend.hard_sigmoid",[]) Relu -> ("tf.nn.relu",[]) ClipByValue lo hi -> ("tf.clip_by_value",[float lo,float hi]) _ -> ("tf." ++ map toLower (show op'), []) Num1Op op' -> funcall op (recx:args) where (op,args) = case op' of Negate -> ("tf.negative",[]) _ -> ("tf." ++ map toLower (show op'), []) MatMul s0 a b c x y -> do recx <- rec (s0 .+. (:*) a ((:*) b Unit)) x recy <- rec (s0 .+. (:*) b ((:*) c Unit)) y return (funcall "tf.matmul" [recx, recy]) BinOp operation s0 s1 _ s2 _ x y -> do recx <- rec (s0 .+. s1) x recy <- rec (s0 .+. s2) y return $ case operation of Simple2Op sop -> let pop = case sop of MkComplex -> "tf.complex" Add -> "tf.add" Divide -> "tf.divide" IntegerDiv -> "tf.math.floordiv" Equal -> "tf.equal" Subtract -> "tf.subtract" Multiply -> "tf.multiply" Minimum -> "tf.minimum" Maximum -> "tf.maximum" Comparision op -> "tf.math." ++ case op of Less -> "less" Greater -> "greater" LessOrEqual -> "less_equal" GreaterOrEqual -> "greater_equal" Logic op -> "tf.math.logical_" ++ case op of And -> "and" Or -> "or" FloorMod -> "tf.math.floorMod" in funcall pop [recx,recy] SigmoidCrossEntropyWithLogits -> func "tf.nn.sigmoid_cross_entropy_with_logits" [] [("labels",recx),("logits",recy)] SparseSoftmaxCrossEntropyWithLogits -> func "tf.nn.sparse_softmax_cross_entropy_with_logits" [] [("labels",recx),("logits",recy)] SoftmaxCrossEntropyWithLogits -> func "tf.nn.softmax_cross_entropy_with_logits" [] [("labels",recx),("logits",recy)] -- FIXME: use _v2 for TF 1.5 ReshapeFrom s t -> do rt <- rec s t return (funcall "tf.reshape" [rt, showShapeMinus sR]) Concat s0 s1 xs -> do let go :: forall s0 s1 ns. SShape s0 -> SShape s1 -> NP (Catable s0 s1 t) ns -> Python [DOC] go _ _ Unit = return [] go s0' s1' (Catable n y :* ys) = (:) <$> rec (s0' .+. n :* s1') y <*> go s0' s1' ys rxs <- go s0 s1 xs return (funcall "tf.concat" [list rxs, text "axis=" <> integer (sListLength s0)]) Transpose s p x -> do rx <- rec s x comment ("transpose: p = " <> text (show p) <> "; " <> text (show s)) return (func "tf.transpose" [rx] [("perm",list (map (integer . permToFun p) [0.. sListLength s-1]))]) Gather indexShape s0 m s1 x ix -> do rx <- rec (s0 .+. ((:*) m s1)) x rix <- rec (s0 .+. indexShape) ix return (func "tf.gather" [named "params" rx, named "indices" rix, named "batch_dims" (integer (sListLength s0)), named "axis" (integer (sListLength s0))] []) GatherND containerShape elementShape indexShape x ix -> do rx <- rec (containerShape .+. elementShape) x rix <- rec (indexShape *: (sListLenAsNat containerShape)) ix return (func "tf.gather_nd" [rx, rix] []) Convolution bs inChans outChans filterShape s0 x filters -> do recx <- rec ((:*) bs (s0 *: inChans)) x recFilters <- rec (filterShape .+. ((:*) inChans ((:*) outChans Unit))) filters return (func "tf.nn.convolution" [recx, recFilters] [("padding",text (show ("SAME"::String))),("data_format", text (show dataFormat))]) where dataFormat = case sListLength filterShape of 1 -> ("NWC" :: String) 2 -> "NHWC" 3 -> "NDHWC" _ -> error "convolution: more than 3 spatial dimensions are not supported!" Pool bs window typ numChans outSpatial x -> do rx <- rec ((:*) bs (zipWithMulSShapes window outSpatial .+. (:*) numChans Unit)) x return (func "tf.nn.pool" [rx, showSShape window, typ'] [("strides", showSShape window), ("padding",text (show ("SAME" :: String)))]) where typ' = text $ (show $ case typ of MaxPool -> "MAX"; AvgPool -> "AVG" :: String) Softmax _ _ x -> do rx <- rec typeSShape x return $ func "tf.nn.softmax" [rx] [("axis","1")] -- _ -> error "Python compiler: case not covered" type Python a = State PyState a generateParameters :: [VarInfo] -> Python [DOC] generateParameters genVars = do -- generate variables forM genVars $ \v -> case v of VarInfo {..} -> case varRef of Ref refId shap typ -> do ii <- case varInitial of Nothing -> return [] Just iii -> do iiii <- case knownSShape shap of Sat -> knownTyp typ $ generatePure iii return [named "initial_value" iiii] var <- newPyVar' shap typ var <-- funcall "tf.Variable" ([named "name" (string refId), named "trainable" (bool varTrainable)] ++ ii) return (pyVarRepr var) -- | Clip a gradient clipByGlobalNorm :: Float -> UntypedExpression -> UntypedExpression clipByGlobalNorm maxNorm x = funcall "tf.clip_by_global_norm" [x,float maxNorm] <> brackets (int 0) -- clip_by_global_norm returns a couple (clipped grads, global_norm) -- | Gradient of wrt. given parameters. grad :: UntypedExpression -> UntypedExpression -> UntypedExpression grad y vars = funcall "tf.gradients" [y, vars] fnToPython ::[VarInfo] -> PreparedFunction -> Python () fnToPython params PreparedFunction{pfInputs = SomeSuch placeHolders, pfOutputs = SomeSuch returned,..} = do -- we can't re-use intermediate computations from initialisers or other functions: modify $ \PyState {..} -> PyState {genPureTable = mempty, genAssignTable = M.empty,..} gen (text "@tf.function") genFun (pfName <> "_fn") (text "training_placeholder": map pyVarInfoRepr params ++ hMapToList @KnownPlaceholder (text . placeholderName) placeHolders) $ do returns <- hfor @KnownPlaceholder returned $ \ph@(PHT x) -> do r <- generatePure x return (placeholderName ph,r) gen (text "return " <> dict returns) return () gen (text pfName <> " = " <> dict [ ("function",text pfName <> "_fn"), ("batched",bool pfBatched), ("placeholders",dict (hMapToList @KnownPlaceholder (\ph -> case placeHolderRef ph of Ref nm shape typ -> (nm, dict [("shape",showSShape shape), ("dtype",showSTyp typ)])) placeHolders))]) return () toPython :: PreparedModel -> Python () toPython PreparedModel {..} = do gen (text "import tensorflow as tf") -- Static stuff: construct and initialise parameters, list placeholders, etc. genFun "mkModel" [] $ do vs <- generateParameters pmParams gen (text "return " <> dict [("batch_size", integer pmBatchSize) ,("parameters",list vs) ,("paramsdict",dict [(varName p, v) | (p,v) <- zip pmParams vs])]) -- Loss/Accur/Predict function forM_ pmFunctions (fnToPython pmParams) return () -- | Batchify and compile a model with simple input to output mapping. compile :: forall batchSize sx tx sy ty sy_ ty_ p . (KnownNat batchSize, KnownShape sx, KnownTyp tx, KnownShape sy, KnownTyp ty, KnownShape sy_, KnownTyp ty_, KnownShape p, KnownLen p) => Options -> Gen (Tensor sx tx -> Tensor sy ty -> ModelOutput ty_ p sy_) -> Python [VarInfo] compile options fGen = knownSShape (typeSShape @sy_ .+. typeSShape @p) ?> compileGen @batchSize options (sequenceA [simpleModel @p <$> fGen]) -- | Batchify and compile a model with generic input to output mapping and states compileGen :: forall bs. (KnownNat bs) => Options -> Gen [Function] -> Python [VarInfo] compileGen options model = toPython pm >> return pmParams where pm@PreparedModel{..} = prepare @bs model prettyT :: forall t. KnownTyp t => HaskType t -> DOC prettyT = case kindVal @(TypKind t) of SInt -> case bitsVal @(TypBits t) of SB32 -> int . fromIntegral SB64 -> int . fromIntegral SBool -> bool SFloat -> case bitsVal @(TypBits t) of SB32 -> float SB64 -> double data PyState = PyState {genId :: Integer ,genText :: S.Seq DOC ,genPureTable :: SSNMap2 Shape Typ T DOC -- ^ Table mapping pointers to their -- interpretations, so that sharing in the data -- structures can be exploited when generating ,genAssignTable :: M.Map String DOC -- ^ Table mapping expressions to variables, so -- that lost sharing can be recovered -- genPeeks :: [(String,UntypedExpression)] } type UntypedExpression = DOC type DOC = Doc () double :: Double -> DOC double = pretty float :: Float -> DOC float = pretty integer :: Integer -> DOC integer = pretty int :: Int -> DOC int = pretty bool :: Bool -> DOC bool = pretty string :: String -> DOC string = dquotes . text ================================================ FILE: TypedFlow/TF.hs ================================================ {-# LANGUAGE InstanceSigs #-} {-| Module : TypedFlow.TF Description : Binding to tensorflow functions Copyright : (c) Jean-Philippe Bernardy, 2017 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental This module provides direct access to the most commonly used TensorFlow functions. Higher-level functions are not defined here. -} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UnicodeSyntax #-} {-# LANGUAGE ApplicativeDo #-} {-# LANGUAGE NoStarIsType #-} module TypedFlow.TF ( -- * Variables, Parameters -- ** Parameters parameter', parameter, parameterDefault, ParamWithDefault(..), -- getParameters, -- ** Persistent variables persistent, modifyPersistent, -- ** Placeholders and outputs -- placeholder, -- peekAt, -- peekAtMany, -- * Operations -- ** Constants zeros, ones, eye, constant, -- ** indexwise unary operators round, sigmoid, relu, floor, square, -- ** Indexwise binary operators addN, (⊕), (⊝), (⊙), (⊘), equal, minT, maxT, -- ** Products (∙), (·), matmul, -- ** Reducers reduceMeanAll, reduceSumAll, reduceMinAll, reduceMaxAll, reduceSum, reduceMean, reduceMin, reduceMax, -- argmax, argmax0, argmax1, softmax0, softmax1, -- ** Gradients -- grad, -- clipByGlobalNorm, clipByValue, -- ** Indexing last0, nth0, nth0', lookupT, lookupManyT, gather, range, reverseT, -- ** Split and concatenate slice, slice0, slice1, litStack0, stack0, unstack0, stack1, concatT, concat0, concat1, consT0, snocT0, headT0, tailT0, initT0, -- ** Reshaping expandDim, expandDim0, squeeze0, expandDim1, flatten2, flatten3, flatten12, flattenN2, inflate2, inflate3, inflate12, reshape, flattenAll, inflateAll, -- ** Transposition transposeN, transposeN', transpose01, transposeN01, -- ** Sequences sequenceMask, -- ** Convolutions convolution, -- ** Misc norm, normalize, stopGradient, cast, oneHot0, oneHot1, -- ** complex numbers expm, conjugate, realPart, -- ** Triangular and band Matrices tril, triu, fillTriangular, fillUpperTriangular, -- ** Testing conditions if_, where_, lessThan, -- * Contrib -- ** Mapping mapT, zipWithT, zipWith3T, mapTT, zipWithTT, -- ** Losses sigmoidCrossEntropyWithLogits, softmaxCrossEntropyWithLogits, sparseSoftmaxCrossEntropyWithLogits, -- ** Initializers noise, Distribution(..), varianceScaling, glorotUniform, -- ** Heterogeneous vectors repeatT, -- ** Heterogeneous heterogeneous vectors repeatHT ) where import Prelude hiding (RealFrac(..)) import GHC.TypeLits import Data.Proxy import TypedFlow.Types import TypedFlow.Types.Proofs import TypedFlow.Abstract import TypedFlow.Broadcast -- | Repeat a flexible-shape constant vector to form a heterogeneous tensor vector. repeatT :: forall (ss :: [Shape]) t. All KnownShape ss => KnownLen ss => (forall s. KnownShape s => T s t) -> HTV t ss repeatT f = zs (typeSList @ss) where zs :: forall (s :: [Shape]). All KnownShape s => SList s -> HTV t s zs Unit = Unit zs (_ :* n) = F f :* zs n -- | Repeat a flexible-shape constant vector to form a heterogeneous tensor vector. repeatHT :: forall ss. All KnownPair ss => KnownLen ss => (forall s t. KnownShape s => KnownTyp t => T s t) -> HHTV ss repeatHT f = zs (typeSList @ss) where zs :: forall s. All KnownPair s => SList s -> HHTV s zs Unit = Unit zs (_ :* n) = Uncurry f :* zs n -- | Declare a parameter to optimize. parameter' :: ∀ (shape :: Shape) t. (KnownTyp t,KnownShape shape) => String -> T shape t -> Gen (T shape t) parameter' = persistent True -- | Create a parameter. parameter :: forall p. KnownTensors p => String -> Gen p -> Gen p parameter s p = travTensor parameter' s =<< p -- | Declare variable which persists between calls to session.run. persistent :: ∀ (shape :: Shape) t. (KnownTyp t,KnownShape shape) => Bool -> String -> T shape t -> Gen (T shape t) persistent trainable name initial = do T . ExternalVar <$> GPVariable trainable name (Just initial) -- | Modify a mutable tensor. Attention: for the assignment to happen, -- the resulting tensor must be evaluated! modifyPersistent :: (KnownShape s,KnownTyp t) => T s t -> T s t -> Gen (T s t) modifyPersistent (T (Variable v)) x = GPModify v x -- FIXME: pattern matching here is poor style. -- type family AddSpatialDims xs ys where -- AddSpatialDims '[x] '[] = '[x] -- AddSpatialDims (x ': xs) (y ': ys) = (x+(y-1)) ': AddSpatialDims xs ys -- -- | Convolution operation with no padding (applying the filter only on positions where the input is fully defined) -- convolutionValid :: forall outputChannels filterSpatialShape inChannels s t. -- KnownLen filterSpatialShape -- => Length filterSpatialShape <= 3 -- => ((1 + Length filterSpatialShape) ~ Length s) -- the last dim of s is the batch size -- => T (inChannels ': AddSpatialDims s filterSpatialShape) t -- ^ input tensor (batched) -- -> T ('[outputChannels,inChannels] ++ filterSpatialShape) t -- ^ filters -- -> T (outputChannels ': s) t -- convolutionValid = untypedConvolution "VALID" -- poolNC :: forall dim s inputSpatialShape channels batchSize t. -- (inputSpatialShape ~ Take dim s, '[batchSize] ~ Drop dim s) => -- T ('[channels] ++ s) t -> -- Vec dim -> String -> String -> -- T ('[channels] ++ s) t -- poolNC (T input) windowShape poolingType padding = -- T (funcall "tf.nn.pool" [input,list (map float (vecToList windowShape)),text poolingType,text padding,named "data_format" (text "NWC")]) -- Difficulty: relate windowSize, inputSpatialShape, outputSpatialShape --------------------------- -- Contrib data VarianceScaleMode = VSFanIn | VSFanOut | VSAvg data Distrib = NormalDistr | UniformDistr -- | Random tensor with variance scaling according to deeplearning lore. varianceScaling :: forall inDim outDim t. KnownNat inDim => (KnownNat outDim, KnownFloat t) => Float -> VarianceScaleMode -> Distrib -> Gen (Tensor '[inDim,outDim] t) varianceScaling factor mode distr = noise $ case distr of UniformDistr -> UniformD (-limit) limit NormalDistr -> TruncatedNormalD limit where fan_in = fromIntegral (natVal (Proxy @inDim)) fan_out = fromIntegral (natVal (Proxy @outDim)) n = max 1 $ case mode of VSFanIn -> fan_in VSFanOut -> fan_out VSAvg -> (fan_in + fan_out) / 2 limit = sqrt ((case distr of NormalDistr -> 1.3; UniformDistr -> 3) * factor / n) glorotUniform :: forall inDim outDim t. KnownNat inDim => (KnownNat outDim, KnownBits t) => Gen (Tensor '[outDim,inDim] ('Typ 'Float t)) glorotUniform = varianceScaling 1 VSAvg UniformDistr -- | 'cons' an element and an array (in the first dimension) consT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n => T s t -> T (n ': s) t -> T (n+1 ': s) t consT0 x xs = plusComm @1 @n #> concat0 (expandDim0 x) xs -- | 'snoc' an element and an array (in the first dimension) snocT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n => KnownLen s => T (n ': s) t -> T s t -> T (n+1 ': s) t snocT0 xs x = concat0 xs (expandDim0 x) headT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n => T (n+1 ': s) t -> T (s) t headT0 xs = nth0 0 xs tailT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n => T (n+1 ': s) t -> T (n ': s) t tailT0 xs = incrPos @n #> -- 0 < n+1 plusMinusAssoc @n @1 @1 #> -- (n+1) - 1 = -- n+ (1 - 1) slice0 @1 @(n+1) xs initT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n => T (n+1 ': s) t -> T (n ': s) t initT0 xs = plusMono @n @1 #> -- n <= n+1 slice0 @0 @n xs ---------------- -- Helpers -- | Product of a matrix of weights with a vector. (∙) :: (KnownNumeric t, KnownNat cols, KnownNat rows, KnownTyp t) => Tensor '[cols, rows] t -> Tensor '[cols] t -> Tensor '[rows] t m ∙ v = squeeze0 (matmul (expandDim0 v) m) infixl 7 ∙ -- | Dot product between two vectors. (·) :: ∀ n t. (KnownNumeric t, KnownNat n) => Tensor '[n] t -> Tensor '[n] t -> Tensor '[] t x · y = reduceSum0 (x ⊙ y) infixl 7 · -- | 2-Norm of a vector norm :: KnownBits t => KnownNat n => T '[n] (Flt t) -> Scalar (Flt t) norm = frobNorm -- | 2-Norm of a tensor frobNorm :: KnownShape s => KnownBits t => T s (Flt t) -> Scalar (Flt t) frobNorm = sqrt . reduceSumAll . square normalize :: (KnownNat n, KnownBits t) => T '[n] (Flt t) -> T '[n] (Flt t) normalize v = mapT (/ (norm v + epsilon)) v where epsilon = 1.0e-8 fillTriangular :: forall n l t. (KnownNat n, KnownNat l, KnownNumeric t, (((l+l)-n) ~ (n*n)), n <= l) => Tensor '[l] t -> Tensor '[n,n] t fillTriangular x = plusMinusAssoc @l @l @n #> tril 0 (inflate2 (concat0 x rr)) where rr :: Tensor '[l - n] t rr = subIneq @l @n #> slice0 @0 @(l-n) (reverseT x) -- @lookupManyT def indices array@ lokup indices in array, returning def if the index is -1 lookupManyT :: forall s n t. KnownNat n => KnownShape s => (KnownNumeric t) => Scalar t -> T s Int32 -> T '[n] t -> T s t lookupManyT def indices array = appRUnit @s #> mapTT @s (\idx -> where_ (equal idx (-1)) def (lookupT idx array)) indices -- | A flexible upper-triangular matrix function: fill the upper triangle with l elements. fillUpperTriangular :: forall n l t. KnownNumeric t => KnownNat n => KnownNat l => T '[l] t -> T '[n,n] t fillUpperTriangular x = zipWithTT @'[n,n] (\i j -> let idx :: Scalar Int32 idx = ((i * (2 * n - i - 3)) `floorDiv` 2 + j - 1) -- The index to lookup in the input array. It is computed from the formula: -- Output[i,j] = (j-i-1) + ∑_k^(i-1) (n-k) -- -- The term j-i-1 is the distance from the upper diagonal. -- The sum is the number of elements in the previous rows in where_ (((j - i) `greaterThan` 0) `logicAnd` (idx `lessThan` l)) (lookupT idx x) zeros) range0 range1 where n, l :: Scalar Int32 n = constant (fromIntegral (natVal (Proxy @n))) l = constant (fromIntegral (natVal (Proxy @l))) -- "j" index range1 :: forall n m w. (KnownNat n, KnownNat m) => KnownBits w => T '[n,m] ('Typ 'Int w) range1 = broadcastT range -- "i" index range0 :: forall n m w. (KnownNat n, KnownNat m) => KnownBits w => T '[n,m] ('Typ 'Int w) range0 = transpose01 range1 ------------------------- -- Generic parameters -- | Create a parameter and initialize it with a suitable default for its type. Control the exact initializer using 'parameter'. parameterDefault :: forall p. ParamWithDefault p => String -> Gen p parameterDefault name = parameter name defaultInitializer -- flattenHTV :: KnownTyp t => All KnownShape xs => HTV t xs -> Tensor '[Sum (Ap (FMap CProduct) xs)] t -- flattenHTV Unit = zeros -- flattenHTV (F x :* xs) = concat0 (flattenAll x) (flattenHTV xs) -- class CProduct (xs :: [Nat]) -- instance Fun CProduct where type Ap CProduct xs = Product xs -- inflateHTV :: ∀ xs s t. (All KnownShape xs, KnownLen s, KnownLen xs) => -- Tensor '[Sum (Ap (FMap CProduct) xs)] t -> Gen (HTV t xs) -- inflateHTV (T x) = do -- v <- newVar -- gen (v <> text " = " <> funcall "tf.split" [x, showShape' (prodshape @xs shapeSList), text "axis=0"]) -- return (mkArr @xs 0 shapeSList v) -- where mkArr :: forall zs. All KnownShape zs => Int -> SList zs -> DOC -> HTV t zs -- mkArr _ LZ _ = Unit -- mkArr i (LS _ n) v = F (unsafeReshape (T (v <> brackets (int i)) )):* mkArr (succ i) n v -- prodshape :: forall zs. All KnownShape zs => SList zs -> [Integer] -- prodshape LZ = [] -- prodshape (LS xx xs) = product (shapeToList' (shapeSListProxy xx)) : prodshape xs -- -- | Gradient of wrt. given parameters. -- grad' :: KnownLen xs => T s Float32 -> HHTV xs -> Gen (HHTV xs) -- grad' (T y) vars = do -- v <- newVar -- v <-- funcall "tf.gradients" [y, list (htoList (hmap (\(Uncurry (T x)) -> K x) vars)) ] -- return (mkArr 0 shapeSList v) -- where mkArr :: forall xs. Int -> SList xs -> DOC -> HHTV xs -- mkArr _ LZ _ = Unit -- mkArr i (LS _ n) v = Uncurry (T (v <> brackets (int i))) :* mkArr (succ i) n v ================================================ FILE: TypedFlow/Types/Proofs.hs ================================================ {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UndecidableSuperClasses #-} {-# LANGUAGE UnicodeSyntax #-} {-# LANGUAGE CPP #-} #if __GLASGOW_HASKELL__ >= 806 {-# LANGUAGE NoStarIsType #-} #endif {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module TypedFlow.Types.Proofs where import Prelude hiding (RealFrac(..)) import GHC.TypeLits import Data.Proxy import TypedFlow.Types hiding (T) import Data.Type.Equality import Unsafe.Coerce import Data.Kind (Type) class SingEq s where testEq :: forall a b. s a -> s b -> Maybe (a :~: b) instance SingEq (Sat KnownNat) where testEq :: forall n m. Sat KnownNat n -> Sat KnownNat m -> Maybe (n :~: m) testEq = testNatEqual natValS :: forall m. Sat KnownNat m -> Integer natValS Sat = natVal (Proxy @m) testNatEqual :: Sat KnownNat m -> Sat KnownNat n -> Maybe (m :~: n) testNatEqual m n = if natValS m == natValS n then Just (unsafeCoerce Refl) else Nothing instance SingEq f => SingEq (NP f) where testEq Unit Unit = Just Refl testEq (x :* xs) (y :* ys) = case (testEq x y, testEq xs ys) of (Just Refl, Just Refl) -> Just Refl _ -> Nothing testEq _ _ = Nothing instance SingEq SKind where testEq SBool SBool = Just Refl testEq SInt SInt = Just Refl testEq SFloat SFloat = Just Refl testEq _ _ = Nothing instance SingEq SNBits where testEq SB32 SB32 = Just Refl testEq SB64 SB64 = Just Refl testEq _ _ = Nothing instance SingEq STyp where testEq (STyp k b Refl) (STyp k' b' Refl) = case (testEq k k', testEq b b') of (Just Refl, Just Refl) -> Just Refl _ -> Nothing -- | Use a reified equality relation (#>) :: (a :~: b) -> ((a ~ b) => k) -> k Refl #> k = k infixr 0 #> -- | Use a reified arbitrary predicate (?>) :: Sat constraint a -> (constraint a => k) -> k Sat ?> k = k infixr 0 ?> -- | Use a reified arbitrary constraint (??>) :: Dict constraint -> (constraint => k) -> k Dict ??> k = k infixr 0 ??> productS :: forall s. SShape s -> Sat KnownNat (Product s) productS s = knownSShape s ?> knownProduct @s ?> Sat plusComm :: forall x y. (x + y) :~: (y + x) plusComm = unsafeCoerce Refl plusCommS :: forall x y px py. px x -> py y -> (x + y) :~: (y + x) plusCommS _ _ = plusComm @x @y plusAssoc :: forall x y z. (x + y) + z :~: x + (y + z) plusAssoc = unsafeCoerce Refl plusAssocS :: forall x y z px py pz. px x -> py y -> pz z -> ((x + y) + z) :~: (x + (y + z)) plusAssocS _ _ _ = plusAssoc @x @y @z prodAssoc :: forall x y z. (x * y) * z :~: x * (y * z) prodAssoc = unsafeCoerce Refl prodAssocS :: forall x y z px py pz. px x -> py y -> pz z -> ((x * y) * z) :~: (x * (y * z)) prodAssocS _ _ _ = prodAssoc @x @y @z prodCommS :: forall x y px py. px x -> py y -> (x * y) :~: (y * x) prodCommS _ _ = unsafeCoerce Refl termCancelation :: forall a b. (a + b) - b :~: a termCancelation = plusMinusAssoc @a @b @b #> cancelation @b #> Refl plusMinusAssoc :: forall x y z. (x + y) - z :~: x + (y - z) plusMinusAssoc = unsafeCoerce Refl cancelation :: (a - a) :~: 0 cancelation = unsafeCoerce Refl plusMono :: forall a b. (a <=? (a+b)) :~: 'True plusMono = unsafeCoerce Refl succPos :: (1 <=? 1+j) :~: 'True -- CmpNat 0 (1 + n) :~: 'LT succPos = unsafeCoerce Refl succPosProx2 :: forall n proxy a. proxy n a -> (0 :<: (1+n)) succPosProx2 _ = succPos @n prodHomo :: forall x y. Product (x ++ y) :~: Product x * Product y prodHomo = unsafeCoerce Refl prodHomoS :: forall x y px py. px x -> py y -> ((Product (x ++ y) :~: (Product x * Product y))) prodHomoS _ _ = prodHomo @x @y knownProduct' :: forall s f. All KnownNat s => NP f s -> Sat KnownNat (Product s) knownProduct' Unit = Sat knownProduct' (_ :* n) = knownProduct' n ?> Sat knownProduct :: forall s. KnownShape s => Sat KnownNat (Product s) knownProduct = knownProduct' @s typeSList knownSumS :: forall s. NP (Sat KnownNat) s -> Sat KnownNat (Sum s) knownSumS Unit = Sat knownSumS (Sat :* n) = knownSumS n ?> Sat knownSum' :: forall s f. All KnownNat s => NP f s -> Sat KnownNat (Sum s) knownSum' proxies = knownSumS (allKnown' proxies) knownSum :: forall s. KnownShape s => Sat KnownNat (Sum s) knownSum = knownSum' @s typeSList knownPlus :: forall m n. KnownNat m => KnownNat n => Sat KnownNat (m + n) knownPlus = Sat takeDrop :: forall s n. (PeanoNat n <= Length s) => (Take n s ++ Drop n s) :~: s takeDrop = unsafeCoerce Refl lengthHomo :: forall x y. Length (x ++ y) :~: Length x + Length y lengthHomo = unsafeCoerce Refl lengthHomoS :: forall x y proxyx proxyy. proxyx x -> proxyy y -> ((Length (x ++ y) :~: (Length x + Length y))) lengthHomoS _ _ = lengthHomo @x @y lengthInit :: forall s. (0 < Length s) => SList s -> ((Length (Init s) + 1) :~: Length s) lengthInit x = lengthHomo @(Init s) @'[Last s] #> initLast x #> Refl type a :<=: b = ((a <=? b):~: 'True) type i :<: j = (i+1) :<=: j incrPos :: forall x. 1 :<=: (x+1) incrPos = unsafeCoerce Refl subIneq :: forall x k. (x - k) :<=: x subIneq = unsafeCoerce Refl incrCong :: forall x y. ((x+1) ~ (y+1)) => x :~: y incrCong = unsafeCoerce Refl initLast :: forall s. {-(0 < Length s) => FIXME -} SList s -> ((Init s ++ '[Last s]) :~: s) initLast Unit = error "initLast': does not hold on empty lists" initLast ((:*) _ Unit) = Refl initLast ((:*) _ ((:*) y ys)) = initLast ((:*) y ys) #> Refl initLast' :: forall s. {-(0 < Length s) => FIXME -} KnownShape s => ((Init s ++ '[Last s]) :~: s) initLast' = initLast (typeSList @s) appRUnit :: forall s. (s ++ '[]) :~: s appRUnit = unsafeCoerce Refl appAssoc :: ((xs ++ ys) ++ zs) :~: (xs ++ (ys ++ zs)) appAssoc = unsafeCoerce Refl appAssocS :: forall xs ys zs proxy1 proxy2 proxy3. proxy1 xs -> proxy2 ys -> proxy3 zs -> (((xs ++ ys) ++ zs) :~: (xs ++ (ys ++ zs))) appAssocS _ _ _ = appAssoc @xs @ys @zs knownLast' :: All KnownNat s => SList s -> (KnownNat (Last s) => k) -> k knownLast' Unit _ = error "knownLast: does not hold on empty lists" knownLast' ((:*) _ Unit) k = k knownLast' ((:*) _ ((:*) y xs)) k = knownLast' ((:*) y xs) k knownLast :: forall s k. KnownShape s => (KnownNat (Last s) => k) -> k knownLast = knownLast' @s typeSList knownInit' :: All KnownNat s => SList s -> Sat KnownShape (Init s) knownInit' Unit = error "knownLast: does not hold on empty lists" knownInit' ((:*) _ Unit) = Sat knownInit' ((:*) _ ((:*) y xs)) = knownInit' ((:*) y xs) ?> Sat knownInit :: forall s. KnownShape s => Sat KnownShape (Init s) knownInit = knownInit' @s typeSList knownTail' :: forall x s k. All KnownNat s => SList (x ': s) -> (KnownShape s => k) -> k knownTail' ((:*) _ Unit) k = k knownTail' ((:*) _ ((:*) y xs)) k = knownTail' ((:*) y xs) k knownTail :: forall s x xs k. (s ~ (x ': xs), KnownShape s) => (KnownShape xs => k) -> k knownTail = knownTail' @x @xs typeSList knownAppendS :: forall s t pt. (All KnownNat s, KnownShape t) => SList s -> pt t -> Sat KnownShape (s ++ t) knownAppendS Unit _t = Sat knownAppendS ((:*) _ n) t = knownAppendS n t ?> Sat knownAppend :: forall s t. (KnownShape s, KnownShape t) => Sat KnownShape (s ++ t) knownAppend = knownAppendS (typeSList @s) (Proxy @t) -- knownFmap' :: forall f xs. SList xs -> SList (Ap (FMap f) xs) -- knownFmap' Unit = Unit -- knownFmap' ((:*) x n) = (:*) Proxy (knownFmap' @f n) knownSList :: NP proxy xs -> Sat KnownLen xs knownSList Unit = Sat knownSList (_ :* n) = knownSList n ?> Sat knownSShape :: SShape xs -> Sat KnownShape xs knownSShape Unit = Sat knownSShape ((:*) Sat s) = knownSShape s ?> Sat data DimExpr (a :: Nat) (x :: Nat) (b :: Nat) where ANat :: Sat KnownNat x -> DimExpr a x (a * x) (:*:) :: DimExpr a x b -> DimExpr b y c -> DimExpr a (x*y) c knownOutputDim :: forall a x b. Sat KnownNat a -> DimExpr a x b -> Sat KnownNat b knownOutputDim a (ANat x) = satMul a x knownOutputDim a (x :*: y) = knownOutputDim (knownOutputDim a x) y dimSat :: DimExpr a x b -> Sat KnownNat x dimSat (ANat s) = s dimSat (x :*: y) = dimSat x `satMul` dimSat y normDim :: forall ws xs ys. DimExpr ws xs ys -> (ws * xs) :~: ys normDim (ANat _) = Refl normDim (a :*:b) = normDim a #> normDim b #> prodAssocS (Proxy @ws) (dimSat a) (dimSat b) #> Refl data ShapeExpr (a :: Nat) (x :: Shape) (b :: Nat) where Single :: DimExpr a x b -> ShapeExpr a '[x] b AShape :: SShape x -> ShapeExpr a x (a * Product x) (:++:) :: ShapeExpr a x b -> ShapeExpr b y c -> ShapeExpr a (x++y) c infixr 5 :++: infixr 5 *:! infixr 5 !:* (!:*) :: DimExpr a x b -> ShapeExpr b xs c -> ShapeExpr a (x ': xs) c x !:* xs = Single x :++: xs (*:!) :: ShapeExpr a xs b -> DimExpr b x c -> ShapeExpr a (xs ++ '[x]) c xs *:! x = xs :++: Single x exprSShape :: forall a x b. ShapeExpr a x b -> SShape x exprSShape (AShape s) = s exprSShape (Single x) = dimSat x ?> typeSShape exprSShape (x :++: y) = exprSShape x .+. exprSShape y normShape :: forall ws xs ys. ShapeExpr ws xs ys -> (ws * Product xs) :~: ys normShape (Single x) = normDim x normShape (AShape _) = Refl normShape (l :++: r) = normShape l #> normShape r #> prodHomoS (exprSShape l) (exprSShape r) #> prodAssocS (Proxy @ws) (productS (exprSShape l)) (productS (exprSShape r)) #> Refl -- r :: normShape b y ys ----> (b * y) ~ ys (1) -- l :: normShape ws x b ----> (ws * x) ~ b (2) -- subst (2) in (1): ((ws * x) * y) ~ ys -- assoc: (ws * (x * y)) ~ ys decideProductEq1 :: forall xs zs. ShapeExpr 1 xs zs -> Product xs :~: zs decideProductEq1 a = case normShape a of Refl -> Refl type ShapeX = ShapeExpr 1 decideProductEq :: ShapeExpr 1 xs zs -> ShapeExpr 1 ys zs -> Product xs :~: Product ys decideProductEq l r = case decideProductEq1 l of Refl -> case decideProductEq1 r of Refl -> Refl unsafePositive :: (1 <=? n) :~: 'True unsafePositive = unsafeCoerce Refl sucPred :: ((1 <=? n) ~ 'True) => (n - 1) + 1 :~: n sucPred = unsafeCoerce Refl -- data ORDEQ p a b where -- LT, GT :: ORDEQ a b -- EQ :: p -> ORDEQ a a data NatExpr n where NEVar :: Int -> NatExpr n (::+) :: NatExpr m -> NatExpr n -> NatExpr (m+n) (::*) :: NatExpr m -> NatExpr n -> NatExpr (m*n) data NatSum n where NSZero :: NatSum 0 NSAdd :: NatProd m -> NatSum n -> NatSum (m+n) data NatProd n where NPUnit :: Sat KnownNat k -> NatProd k NPTimes :: Sat KnownNat m -> Int -> NatProd n -> NatProd (m*n) sortProd :: NatProd n -> NatProd n sortProd (NPUnit k) = NPUnit k sortProd (NPTimes x xId y) = insertProd x xId (sortProd y) where insertProd :: Sat KnownNat m -> Int -> NatProd n -> NatProd (m*n) insertProd x xId rest = case rest of (NPUnit k) -> NPTimes x xId rest (NPTimes y yId ys) -> if xId <= yId then NPTimes x xId rest else prodAssocS x y ys #> prodCommS x y #> prodAssocS y x ys #> NPTimes y yId (insertProd x xId ys) sortSum :: NatSum n -> NatSum n sortSum NSZero = NSZero sortSum (NSAdd x y) = insertTerm (sortProd x) (sortSum y) where insertTerm :: NatProd m -> NatSum n -> NatSum (m+n) insertTerm p rest = case rest of NSZero -> NSAdd p NSZero NSAdd q qs -> case compareTerms p q of Right p' -> plusAssocS p q qs #> NSAdd p' qs Left False -> NSAdd p rest Left True -> plusAssocS p q qs #> plusCommS p q #> plusAssocS q p qs #> NSAdd q (insertTerm p qs) compareTerms :: NatProd n -> NatProd m -> Either Bool (NatProd (n+m)) compareTerms (NPUnit Sat) (NPUnit Sat) = Right (NPUnit Sat) compareTerms (NPUnit _) (NPTimes _ _ _) = Left False compareTerms (NPTimes _ _ _) (NPUnit _) = Left True compareTerms (NPTimes x xId xs) (NPTimes y yId ys) = case testEq x y of Nothing -> Left (natValS x <= natValS y) Just Refl -> case compareTerms xs ys of Left x -> Left x Right p -> distrLS x xs ys #> Right (NPTimes x xId p) distrLS :: forall a b c px py pz. px a -> py b -> pz c -> a * (b + c) :~: (a * b + a * c) distrLS = unsafeCoerce Refl distrRS :: forall a b c px py pz. px a -> py b -> pz c -> (a + b) * c :~: ((a * c) + (b * c)) distrRS = unsafeCoerce Refl expandProd :: NatSum m -> NatSum n -> NatSum (m*n) expandProd NSZero _ = NSZero expandProd _ NSZero = NSZero expandProd (NSAdd a b) c = distrRS a b c #> expandSum (expandP' a c) (expandProd b c) expandP' :: NatProd m -> NatSum n -> NatSum (m*n) expandP' _ NSZero = NSZero expandP' p (NSAdd q a) = distrLS p q a #> NSAdd (expandPP p q) (expandP' p a) expandPP :: NatProd m -> NatProd n -> NatProd (m*n) expandPP (NPUnit k) (x) = expandKP k x expandPP (NPTimes x xId y) z = prodAssocS x y z #> NPTimes x xId (expandPP y z) expandKP :: Sat KnownNat k -> NatProd n -> NatProd (k*n) expandKP Sat (NPUnit Sat) = NPUnit Sat expandKP k@Sat (NPTimes x xId y) = prodAssocS k x y #> prodCommS k x #> prodAssocS x k y #> NPTimes x xId (expandKP k y) expandSum :: NatSum m -> NatSum n -> NatSum (m+n) expandSum NSZero x = x expandSum (NSAdd x y) z = plusAssocS x y z #> NSAdd x (expandSum y z) natRec :: forall (n :: Nat) (p :: Nat -> Type). KnownNat n => p 0 -> (forall (m :: Nat). p m -> p (m+1)) -> p n natRec z s = case natVal (Proxy @n) of 0 -> unsafeCoerce z _ -> case unsafePositive @n of Refl -> case sucPred @n of Refl -> s @(n-1) (natRec @(n-1) @p z s) data CountRes n where CountRes :: Integer -> V n Integer -> CountRes n vcount :: forall n. KnownNat n => V n Integer vcount = case natRec @n (CountRes (natVal (Proxy @n)-1) VUnit) (\(CountRes m xs) -> plusCommS (Proxy @1) (F xs) #> CountRes (m-1) (m :** xs)) of CountRes _ x -> x data V n a where VUnit :: V 0 a (:**) :: a -> V n a -> V (1+n) a infixr 5 :** deriving instance (Functor (V n)) instance KnownNat n => Applicative (V n) where pure x = fmap (const x) (vcount @n) VUnit <*> VUnit = VUnit (f :** fs) <*> (a :** as) = succPosProx2 fs #> (f a :** (fs <*> unsafeCoerce as)) ================================================ FILE: TypedFlow/Types.hs ================================================ {-# LANGUAGE QuantifiedConstraints #-} {-# LANGUAGE CPP #-} #if __GLASGOW_HASKELL__ >= 806 {-# LANGUAGE NoStarIsType #-} #endif {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE ConstraintKinds #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveFoldable #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE DeriveTraversable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE MagicHash #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# LANGUAGE UndecidableSuperClasses #-} {-# LANGUAGE UnicodeSyntax #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ApplicativeDo #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} module TypedFlow.Types where import GHC.TypeLits import Data.Proxy -- import Control.Monad.State -- import Control.Monad.RWS (RWS(..), local, ask, tell) import Data.Kind (Constraint,Type) import qualified Data.Int as Hask import Data.Type.Equality import Data.Monoid hiding (Sum,Product,Last,All,Ap) import Data.Complex newtype (∘) f (g :: k -> k2) (a::k) where Comp :: forall f g a. f (g a) -> (f ∘ g) a type Sat = (∘) Dict type Sat' f x = f x data Dict :: Constraint -> Type where Dict :: a => Dict a pattern Sat :: forall k (g :: k -> Constraint) (a :: k). () => g a => (∘) Dict g a -- the second context is the PROVIDED constraint! pattern Sat = Comp Dict instance (Show (Sat a b)) where show _ = "Sat" proxySat :: forall k (b::k) (a :: k -> Constraint) proxy. a b => proxy b -> Sat a b proxySat _ = Sat natSat :: forall n. KnownNat n => Sat KnownNat n natSat = Sat -- type i < j = CmpNat i j ~ 'LT type i < j = (i+1) <= j -- type i <= j = (i <=? j) ~ 'True type family Product xs where Product '[] = 1 Product (x ': xs) = x * Product xs type family Sum xs where Sum '[] = 0 Sum (x ': xs) = x + Sum xs type family (++) xs ys where '[] ++ xs = xs (x ': xs) ++ ys = x ': (xs ++ ys) type family Tail xs where Tail (x ': xs) = xs type family Last xs where Last '[x] = x Last (x ': xs) = Last xs type family Init xs where Init '[x] = '[] Init (x ': xs) = x ': Init xs type family Length xs where Length '[] = 0 Length (x ': xs) = 1 + Length xs type family Reverse' xs ys where Reverse' '[] ys = ys Reverse' (x ': xs) ys = Reverse' xs (x ': ys ) type family Reverse xs where Reverse xs = Reverse' xs '[] -- From: https://www.cs.ox.ac.uk/projects/utgp/school/andres.pdf data NP f (xs :: [k]) where Unit :: NP f '[] (:*) :: f x -> NP f xs -> NP f (x ': xs) deriving instance (forall x. Show (f x)) => Show (NP f xs) type SList' = NP (.+.) = appSList infixr 5 .+. infixr 5 *: infixr 5 :* (*:) :: forall x xs f. NP f xs -> f x -> NP f (xs ++ '[x]) xs *: x = appSList xs (x :* Unit) hlookup :: Axis n xs -> NP f xs -> f (At n xs) hlookup AxZero (x :* _) = x hlookup (AxSucc n) (_ :* xs) = hlookup n xs newtype I a = I a newtype K a x = K a type HList = NP I pattern HSingle :: f a -> NP f '[a] pattern HSingle x = x :* Unit pattern VecSing :: Tensor s t -> HTV t '[s] pattern VecSing t1 = F t1 :* Unit pattern VecPair :: Tensor s t -> Tensor s' t -> HTV t '[s,s'] pattern VecPair t1 t2 = F t1 :* F t2 :* Unit pattern VecTriple :: Tensor s t -> Tensor s' t -> Tensor s3 t -> HTV t '[s,s',s3] pattern VecTriple t1 t2 t3 = F t1 :* F t2 :* F t3 :* Unit type family All (c :: k -> Constraint) (xs :: [k]) :: Constraint where All c '[] = () All c (x ': xs) = (c x, All c xs) knownAll :: forall constraint s k. NP (Sat constraint) s -> (All constraint s => KnownLen s => k) -> k knownAll Unit k = k knownAll (Sat :* xs) k = knownAll xs $ k allKnown' :: forall constraint s proxy. All constraint s => NP proxy s -> NP (Sat constraint) s allKnown' Unit = Unit allKnown' (_ :* xs) = Sat :* allKnown' xs allKnown :: forall k s. KnownLen s => All k s => NP (Sat k) s allKnown = allKnown' typeSList data SomeSuch k f where SomeSuch :: k x => f x -> SomeSuch k f class Fun (c :: k -> Constraint) where -- FIXME: use type, not constraint? type Ap c (t :: k) :: l class Cons (x :: k) (xs :: [k]) instance Fun (Cons x) where type Ap (Cons x) xs = x ': xs class Snoc (x :: k) (xs :: [k]) instance Fun (Snoc x) where type Ap (Snoc x) '[] = '[x] type Ap (Snoc x) (y ': ys) = y ': Ap (Snoc x) ys class FMap (c :: k -> Constraint) (xs :: [k]) where instance Fun c => Fun (FMap c) where type Ap (FMap c) '[] = '[] type Ap (FMap c) (x ': xs) = Ap c x ': Ap (FMap c) xs mapFMap :: forall g f xs. (forall x. f x -> f (Ap g x)) -> NP f xs -> NP f (Ap (FMap g) xs) mapFMap _ Unit = Unit mapFMap f (x :* xs) = f x :* mapFMap @g @f f xs -- type family All2 (c :: k -> l -> Constraint) (xs :: [k]) (ys :: [l]) :: Constraint where -- All2 c '[] '[] = () -- All2 c (x ': xs) (y ': ys) = (c x y, All2 c xs ys) -- All2 c '[] (y ': ys) = 'True ~ 'False -- All2 c (y ': ys) '[] = 'True ~ 'False -- | Flip at type level newtype F g t s = F {fromF :: g s t} -- | Tensor vector. (Elements in the indexing list are ignored.) type TV s t = NP (K (Tensor s t)) -- | Heterogeneous tensor vector with varying shapes and the same kind of elements type HTV t = NP (F T t) class Scnd' (x::(a,b)) instance Fun (Scnd') where type Ap Scnd' '(a,b) = b class Frst' (x::(a,b)) instance Fun (Frst') where type Ap Frst' '(a,b) = a type family Frst (x :: (a,b)) where Frst '(x,y) = x type family Scnd (x :: (a,b)) where Scnd '(x,y) = y type family Frst3 (x :: (a,b,c)) where Frst3 '(x,y,z) = x type family Scnd3 (x :: (a,b,c)) where Scnd3 '(x,y,z) = y type family Thrd3 (x :: (a,b,c)) where Thrd3 '(x,y,z) = z class (KnownShape (Scnd3 r), KnownTyp (Thrd3 r), KnownSymbol (Frst3 r)) => KnownPlaceholder r where placeHolderRef :: proxy r -> Ref String (Scnd3 r) (Thrd3 r) instance (KnownShape y, KnownTyp z, KnownSymbol x) => KnownPlaceholder '(x,y,z) where placeHolderRef _ = Ref (symbolVal (Proxy @x)) typeSShape typeSTyp class (KnownShape (Frst r), KnownTyp (Scnd r)) => KnownPair r instance (KnownShape x, KnownTyp y) => KnownPair '(x,y) newtype Uncurry g (s :: (a,b)) = Uncurry {fromUncurry :: g (Frst s) (Scnd s)} -- | Tensor vector heterogenous in types and shapes. type HHTV = NP (Uncurry T) type Placeholders = NP Placeholder type PH = (Symbol,Shape,Typ) newtype Placeholder (s :: PH) = PHT (T (Scnd3 s) (Thrd3 s)) hhead :: NP f (x ': xs) -> f x hhead (x :* _) = x htail :: NP f (x ': xs) -> NP f xs htail (_ :* xs) = xs htmap :: forall f ss t u. (forall s. Tensor s t -> Tensor (Ap f s) u) -> HTV t ss -> HTV u (Ap (FMap f) ss) htmap _ Unit = Unit htmap f (F x :* xs) = F (f x) :* htmap @f f xs -- htmap' :: forall f ss t u. All KnownShape ss => (forall s. KnownShape s => Tensor (Ap f s) t -> Tensor s u) -> SList ss -> HTV t (Ap (FMap f) ss) -> HTV u ss -- htmap' _ Unit Unit = Unit -- htmap' f ((:*) _ n)(F x :* xs) = F (f x) :* htmap' @f f n xs -- | Map a natural transformation hmap :: (forall x. f x -> g x) -> NP f xs -> NP g xs hmap _ Unit = Unit hmap f (x :* xs) = f x :* hmap f xs hTraverse :: forall f g xs m. Applicative m => (forall x. f x -> m (g x)) -> NP f xs -> m (NP g xs) hTraverse _ Unit = pure Unit hTraverse f (x :* xs) = do x' <- f x xs' <- hTraverse f xs return (x' :* xs') -- | Variant of hmap with a constraint hmapK :: forall k f g xs. All k xs => (forall x. k x => f x -> g x) -> NP f xs -> NP g xs hmapK _ Unit = Unit hmapK f (x :* xs) = f x :* hmapK @k f xs hMapToList :: forall k f xs a. All k xs => (forall x. k x => f x -> a) -> NP f xs -> [a] hMapToList f = htoList . hmapK @k (K . f) -- | If NP is in fact a vector, we have a "usual" map. kmap :: (a -> b) -> NP (K a) xs -> NP (K b) xs kmap _ Unit = Unit kmap f (K x :* xs) = K (f x) :* kmap f xs -- | If NP is in fact a tuple, we can apply a tuple of endomorphisms. (special case of <*>) hendo :: NP Endo xs -> HList xs -> HList xs hendo Unit Unit = Unit hendo (Endo f :* fs) (I x :* xs) = (I (f x) :* hendo fs xs) appSList, (.+.), happ :: NP f xs -> NP f ys -> NP f (xs ++ ys) happ Unit xs = xs happ (x :* xs) ys = x :* (happ xs ys) appSList = happ data Both f g x = Both {frst :: f x, scnd :: g x} bothFromPair :: (f x, g x) -> Both f g x bothFromPair (x,y) = (Both x y) bothToPair :: Both f g x -> (f x, g x) bothToPair (Both x y) = (x,y) hzip :: NP f xs -> NP g xs -> NP (Both f g) xs hzip = hzipWith Both hzipWith :: (forall x. f x -> g x -> h x) -> NP f xs -> NP g xs -> NP h xs hzipWith _ Unit Unit = Unit hzipWith f (x :* xs) (y :* ys) = f x y :* hzipWith f xs ys hfor :: forall k f xs m a. All k xs => Applicative m => NP f xs -> (forall x. k x => f x -> m a) -> m [a] hfor Unit _ = pure [] hfor (x :* xs) f = (:) <$> f x <*> hfor @k xs f htoList :: NP (K a) xs -> [a] htoList Unit = [] htoList (K x :* xs) = x : htoList xs hsplit' :: SPeano n -> NP f xs -> (NP f (Take n xs), NP f (Drop n xs)) hsplit' SZero xs = (Unit,xs) hsplit' (SSucc _n) Unit = (Unit,Unit) hsplit' (SSucc n) (x :* xs) = case hsplit' n xs of (l,r) -> (x :* l,r) hsplit :: forall xs ys f. KnownLen xs => NP f (xs++ys) -> (NP f xs, NP f ys) hsplit xys = splitApp @xs @ys (hsplit' (shapePeano @xs) xys) splitApp' :: forall ys xs k. SList xs -> ((Take (PeanoLength xs) (xs ++ ys) ~ xs, Drop (PeanoLength xs) (xs ++ ys) ~ ys) => k) -> k splitApp' Unit k = k splitApp' ((:*) _ n) k = splitApp' @ys n k splitApp :: forall xs ys k. KnownLen xs => ((Take (PeanoLength xs) (xs ++ ys) ~ xs, Drop (PeanoLength xs) (xs ++ ys) ~ ys) => k) -> k splitApp = splitApp' @ys (typeSList @xs) hsnoc :: NP f xs -> f x -> NP f (xs ++ '[x]) hsnoc xs x = happ xs (x :* Unit) data Peano = Zero | Succ Peano -- TODO: type Peano = '[()] (And then SPeano = NP) ? axis0 :: Axis 'Zero (x ': xs) axis0 = AxZero axis1 :: Axis ('Succ 'Zero) (x0 ': (x1 ': xs)) axis1 = AxSucc axis0 axis2 :: Axis ('Succ ('Succ 'Zero)) (x0 ': (x1 ': (x2 ': xs))) axis2 = AxSucc axis1 axis3 :: Axis ('Succ ('Succ ('Succ 'Zero))) (x0 ': (x1 ': (x2 ': (x3 ': xs)))) axis3 = AxSucc axis2 data Axis n xs where AxZero :: Axis 'Zero (x ': xs) AxSucc :: Axis n xs -> Axis ('Succ n) (x ': xs) axisInt :: Axis n xs -> Integer axisInt AxZero = 0 axisInt (AxSucc n) = 1 + axisInt n sPeanoInt :: SPeano n -> Integer sPeanoInt (SSucc n) = 1 + sPeanoInt n sPeanoInt SZero = 0 type family PeanoNat (n::Peano) :: Nat where PeanoNat 'Zero = 0 PeanoNat ('Succ n) = PeanoNat n + 1 data SPeano n where SZero :: SPeano 'Zero SSucc :: SPeano n -> SPeano ('Succ n) class KnownPeano n where knownPeano :: SPeano n instance KnownPeano 'Zero where knownPeano = SZero instance KnownPeano n => KnownPeano ('Succ n) where knownPeano = SSucc knownPeano type family Take n xs where Take 'Zero xs = '[] Take ('Succ n) '[] = '[] Take ('Succ n) (x ': xs) = x ': Take n xs type family Drop n xs where Drop 'Zero xs = xs Drop _ '[] = '[] Drop ('Succ n) (x ': xs) = Drop n xs type family At n xs where At 'Zero (x ': xs) = x At ('Succ n) (x ': xs) = At n xs -- type family Drop n xs where -- Drop 'Zero xs = xs -- Drop _ '[] = '[] -- Drop ('Succ n) (x ': xs) = Drop n xs -- type family At n xs where -- At 'Zero (x ': xs) = x -- At ('Succ n) (x ': xs) = At n xs data Kind = Float | Cmplx | Int | Bool deriving (Show,Eq,Ord) data SKind (s::Kind) where SFloat :: SKind 'Float SCmplx :: SKind 'Cmplx SInt :: SKind 'Int SBool :: SKind 'Bool data NBits = B32 | B64 | B1 deriving (Show,Eq,Ord) data SNBits s where SB32 :: SNBits 'B32 SB64 :: SNBits 'B64 data Typ = Typ Kind NBits deriving (Eq,Ord) type family TypKind (t :: Typ) where TypKind ('Typ k b) = k type family TypBits (t :: Typ) where TypBits ('Typ k b) = b type KnownNumeric t = (NumericKind (TypKind t), KnownBits (TypBits t), t ~ 'Typ (TypKind t) (TypBits t)) type KnownFloat t = (TypKind t ~ 'Float, KnownBits (TypBits t), t ~ 'Typ 'Float (TypBits t)) type KnownAlgebraic t = (AlgebraicKind (TypKind t), KnownBits (TypBits t), t ~ 'Typ (TypKind t) (TypBits t)) class KnownKind t => NumericKind t where instance NumericKind 'Float instance NumericKind 'Cmplx instance NumericKind 'Int class NumericKind t => AlgebraicKind t where instance AlgebraicKind 'Float instance AlgebraicKind 'Cmplx kVal :: SKind t1 -> Kind kVal SFloat = Float kVal SInt = Int kVal SBool = Bool kVal SCmplx = Cmplx instance Eq (SKind t) where x == y = kVal x == kVal y instance Ord (SKind t) where compare x y = compare (kVal x) (kVal y) nbitsVal :: SNBits w -> NBits nbitsVal SB64 = B64 nbitsVal SB32 = B32 instance Eq (SNBits t) where x == y = nbitsVal x == nbitsVal y instance Ord (SNBits t) where compare x y = compare (nbitsVal x) (nbitsVal y) sTypTyp :: STyp t1 -> Typ sTypTyp (STyp k b Refl) = Typ (kVal k) (nbitsVal b) instance Eq (STyp t) where x == y = sTypTyp x == sTypTyp y instance Ord (STyp t) where compare x y = compare (sTypTyp x) (sTypTyp y) data STyp t where STyp :: SKind (TypKind t) -> SNBits (TypBits t) -> (t :~: 'Typ (TypKind t) (TypBits t)) -> STyp t type Flt t = 'Typ 'Float t type Float32 = 'Typ 'Float 'B32 type Complex32 = 'Typ 'Cmplx 'B32 type Int32 = 'Typ 'Int 'B32 type Int64 = 'Typ 'Int 'B64 type TFBool = 'Typ 'Bool 'B32 type Scalar t = T '[] t type Shape = [Nat] class (KnownLen s, All KnownNat s) => KnownShape s where instance KnownShape '[] instance (KnownNat x, KnownShape xs) => KnownShape (x ': xs) type KnownTyp t = (KnownBits (TypBits t), KnownKind (TypKind t), t ~ 'Typ (TypKind t) (TypBits t)) typeSTyp :: forall t. KnownTyp t => STyp t typeSTyp = STyp (kindVal @(TypKind t)) (bitsVal @(TypBits t)) Refl type family HaskType t where HaskType Float32 = Float HaskType ('Typ 'Float 'B64) = Double HaskType ('Typ 'Cmplx 'B32) = Complex Float HaskType ('Typ 'Cmplx 'B64) = Complex Double HaskType ('Typ 'Int 'B64) = Hask.Int64 HaskType ('Typ 'Int 'B32) = Hask.Int32 HaskType ('Typ 'Bool w) = Bool class KnownBits t where bitsVal :: SNBits t instance KnownBits 'B32 where bitsVal = SB32 instance KnownBits 'B64 where bitsVal = SB64 typVal :: forall t. KnownTyp t => Typ typVal = Typ (kVal k) (nbitsVal b) where k = kindVal @(TypKind t) b = bitsVal @(TypBits t) knownBits :: SNBits t -> (KnownBits t => Fractional (HaskType ('Typ 'Float t)) => Floating (HaskType ('Typ 'Float t)) => k) -> k knownBits SB32 k = k knownBits SB64 k = k knownKind :: SKind t -> (KnownKind t => k) -> k knownKind SFloat k = k knownKind SInt k = k knownKind SBool k = k knownKind SCmplx k = k knownTyp :: STyp t -> (KnownTyp t => k) -> k knownTyp (STyp k b Refl) r = knownKind k $ knownBits b r knownAlgebraic :: forall t k. KnownAlgebraic t => ((Fractional (HaskType t), Floating (HaskType t)) => k) -> k knownAlgebraic k = case kindVal @(TypKind t) of SFloat -> case bitsVal @(TypBits t) of SB32 -> k SB64 -> k SCmplx -> case bitsVal @(TypBits t) of SB32 -> k SB64 -> k _ -> error "KnownAlgebraic bug" knownNum :: forall t k. KnownNumeric t => (KnownTyp t => Num (HaskType t) => k) -> k knownNum k = case kindVal @(TypKind t) of SFloat -> case bitsVal @(TypBits t) of SB32 -> k SB64 -> k SCmplx -> case bitsVal @(TypBits t) of SB32 -> k SB64 -> k SBool -> error "KnownNumeric bug" SInt -> case bitsVal @(TypBits t) of SB32 -> k SB64 -> k class KnownKind t where kindVal :: SKind t instance KnownKind 'Bool where kindVal = SBool instance KnownKind 'Cmplx where kindVal = SCmplx instance KnownKind 'Float where kindVal = SFloat instance KnownKind 'Int where kindVal = SInt type SList = NP Proxy instance Ord (Sat KnownNat t) where compare x@Sat y@Sat = compare (natVal x) (natVal y) instance Eq (Sat KnownNat t) where x@Sat == y@Sat = (natVal x) == (natVal y) type SShape = NP (Sat KnownNat) instance Ord (SShape s) where compare x y = compare (shapeToList' x) (shapeToList' y) instance Eq (SShape s) where Unit == Unit = True ((:*) x xs) == ((:*) y ys) = x == y && xs == ys instance {-# OVERLAPPING #-} Show (SShape s) where show x = show (shapeToList' x) sListLength :: NP f s -> Integer sListLength Unit = 0 sListLength ((:*) _ s) = 1+sListLength s sListLen :: NP f s -> Int sListLen = fromIntegral . sListLength sListLenAsNat :: NP f s -> Sat KnownNat (Length s) sListLenAsNat Unit = Sat sListLenAsNat ((:*) _ s) = case sListLenAsNat s of Sat -> Sat type family PeanoLength xs :: Peano where PeanoLength '[] = 'Zero PeanoLength (x ': xs) = 'Succ (PeanoLength xs) withKnownNat :: forall k. Int -> (forall (n::Nat). KnownNat n => Proxy n -> k) -> k withKnownNat 0 f = f (Proxy @0) withKnownNat 1 f = f (Proxy @1) withKnownNat n f = withKnownNat (n `div` 2) (if n `mod` 2 == 0 then f2x else f2x1) where f2x,f2x1 :: forall (n::Nat). KnownNat n => Proxy n -> k f2x _ = f (Proxy @(n*2)) f2x1 _ = f (Proxy @(n*2+1)) -- Probably a GHC bug: -- withKnownNat'' :: forall k. Int -> (forall (n::Nat). KnownNat n => k) -> k -- withKnownNat'' 0 f = f @0 -- withKnownNat'' n f = withKnownNat'' (n-1) fsucc -- where fsucc :: forall (n::Nat). KnownNat n => k -- fsucc = f @(n+1) -- This also fails: -- appProxy :: forall (n::Nat) k. KnownNat n => Proxy n -> (forall (m::Nat). KnownNat m => k) -> k -- appProxy f _ = f @n -- withKnownNat :: forall k. Int -> (forall (n::Nat). KnownNat n => k) -> k -- withKnownNat n f = withKnownNat' n (\proxy -> appProxy proxy f) class KnownNat (Length s) => KnownLen s where shapePeano :: SPeano (PeanoLength s) typeSList :: SList s instance KnownLen '[] where shapePeano = SZero typeSList = Unit instance KnownLen xs => KnownLen (x ': xs) where shapePeano = SSucc (shapePeano @xs) typeSList = (:*) Proxy (typeSList @xs) listTypeLen :: forall xs. KnownLen xs => Integer listTypeLen = sListLength (typeSList @xs) typeSListProxy :: KnownLen xs => proxy xs -> SList xs typeSListProxy _ = typeSList sListProxy :: NP f xs -> Proxy xs sListProxy _ = Proxy knownNatVal :: forall x. Sat KnownNat x -> Integer knownNatVal Sat = natVal (Proxy @x) shapeToList' :: SShape s -> [Integer] shapeToList' Unit = [] shapeToList' ((:*) x xs) = knownNatVal x : shapeToList' xs shapeToList'' :: All KnownNat s => NP proxy s -> [Integer] shapeToList'' Unit = [] shapeToList'' ((:*) x xs) = natVal x : shapeToList'' xs shapeToList :: ∀(s::Shape). KnownShape s => [Integer] shapeToList = shapeToList'' (typeSList @ s) typeSShape :: forall s. KnownShape s => SShape s typeSShape = sListSShape (typeSList @s) proxySShape :: forall s. KnownShape s => Proxy s -> SShape s proxySShape _ = typeSShape sListSShape :: forall s. All KnownNat s => SList s -> SShape s sListSShape = allKnown' type None = 514229 -- fibonnaci prime. -- type None = 0 - 1 -- GHC does not like negative Nats. -- Using a maybe type would be a RPITA. -------------------------------- -- Generation Effects (TODO: move to other module) data VarInfo = forall s t. (KnownShape s, KnownTyp t) => VarInfo {varTrainable :: Bool, varRef :: Ref String s t, varInitial :: Maybe (T s t)} varName :: VarInfo -> String varName VarInfo {varRef = Ref {..}} = refName data GState = GState {nextVar :: Integer, -- ^ next free variable genRegularizers :: [Scalar Float32] -- ^ accumulated regularizers } initialGstate :: GState initialGstate = (GState {nextVar = 0 ,genRegularizers=[] }) data Gen a where GPId :: Gen Integer GPVariable :: forall (shape :: Shape) t. (KnownTyp t,KnownShape shape) => Bool -> String -> Maybe (T shape t) -> Gen (Ref String shape t) GPModify :: (KnownShape s,KnownTyp t) => Ref Int s t -> T s t -> Gen (T s t) GPReturn :: a -> Gen a GPState :: (GState -> (a,GState)) -> Gen a GPApp :: (Gen (a -> b)) -> Gen a -> Gen b GPBind :: Gen a -> (a -> Gen b) -> Gen b genGets :: (GState -> a) -> Gen a genGets f = GPState (\s -> (f s, s)) instance Applicative Gen where (<*>) = GPApp pure = GPReturn instance Monad Gen where (>>=) = GPBind instance Functor Gen where fmap f = (pure f <*>) -------------------------- -- Tensors -- | An indexing tensor in the format expected by GatherND type IndexTensor indexShape containerShape w = T (indexShape ++ '[Length containerShape]) ('Typ 'Int w) -- | Description of a random distribution data Distribution (s :: Shape) (t :: Typ) where -- | Each element is from a truncated normal distribution with given standard dev. TruncatedNormalD :: Float -> Distribution s ('Typ 'Float w) -- | Each element is from a uniform distribution with given bounds (low, high) UniformD :: Float -> Float -> Distribution s ('Typ 'Float w) OrthogonalD :: Distribution '[m,n] ('Typ 'Float w) data Ref r s t = Ref {refName :: r, refShape :: SShape s, refTyp :: STyp t} data NilOp s t where ExternalVar :: Ref String s t -> NilOp s t Variable :: Ref Int s t -> NilOp s t Constant :: HaskType t -> NilOp '[] t Range :: KnownBits w => Sat KnownNat n -> NilOp '[n] ('Typ 'Int w) data Catable s1 s2 t n = Catable (Sat KnownNat n) (T (s1 ++ (n ': s2)) t) -- deriving Show type Unique = Int data T (s :: Shape) (t :: Typ) where BroadcastT :: KnownTyp t => Maybe Unique -> Bool -> Sat KnownNat n -> SShape s -> T s t -> T (n ': s) t MapT :: KnownTyp t => Sat KnownNat n -> SShape s -> (T s t -> T r u) -> T (n ': s) t -> T (n ': r) u ZipT :: (KnownTyp t, KnownTyp u) => Sat KnownNat n -> SShape s -> SShape r -> (T s t -> T r u -> T q v) -> T (n ': s) t -> T (n ': r) u -> T (n ': q) v Zip3T :: (KnownTyp t, KnownTyp u, KnownTyp v) => Sat KnownNat n -> SShape s -> SShape r -> SShape q -> (T s t -> T r u -> T q v -> T p w) -> T (n ': s) t -> T (n ': r) u -> T (n ': q) v -> T (n ': p) w T :: NilOp s t -> T s t Noise :: Integer -> -- this is the unique noise identifier, preventing two different noises to ever be re-shared. SShape s0 -> SShape s1 -> Distribution s1 t -> T (s0 ++ s1) t BinOp :: (KnownTyp t,KnownTyp u) => BinOp s1 t s2 u s3 v -> SShape s0 -> SShape s1 -> STyp t -> SShape s2 -> STyp u -> T (s0 ++ s1) t -> T (s0 ++ s2) u -> T (s0 ++ s3) v UnOp :: KnownTyp t => UnOp s1 t s2 u -> SShape s0 -> T (s0 ++ s1) t -> T (s0 ++ s2) u Unbroadcast :: Sat KnownNat n -> Unique -> T (n ': s) t -> T s t DirectBroadcast :: SShape s0 -> NP proxy' s1 -> SShape s2 -> NP proxy' s3 -> T (s0 ++ s2) t -> T (s0 ++ (s1 ++ (s2 ++ s3))) t ReshapeFrom :: Product s ~ Product s0 => SShape s0 -> T s0 t -> T s t Transpose :: SShape s0 -> Permutation s0 s -> T s0 t -> T s t Concat :: SShape s0 -> SShape s1 -> NP (Catable s0 s1 t) ns -> T (s0 ++ (Sum ns ': s1)) t Gather :: KnownTyp ('Typ 'Int w) => SShape indexShape -> SShape s0 -> Sat KnownNat m -> SShape s1 -> T (s0 ++ (m ': s1)) t -> T (s0 ++ indexShape) ('Typ 'Int w) -> T (s0 ++ indexShape ++ s1) t GatherND :: KnownTyp ('Typ 'Int w) => SShape containerShape -> SShape elementShape -> SShape indexShape -> T (containerShape ++ elementShape) t -> IndexTensor indexShape containerShape w -> T (indexShape ++ elementShape) t MatMul :: forall s m n o t. KnownNumeric t => SShape s -> Sat KnownNat n -> Sat KnownNat o -> Sat KnownNat m -> T (s ++ '[n,o]) t -> T (s ++ [o,m]) t -> T (s ++ [n,m]) t Where :: T s TFBool -> T s t -> T s t -> T s t If :: Scalar TFBool -> T s t -> T s t -> T s t Convolution :: KnownAlgebraic t => Sat KnownNat bs -> Sat KnownNat inChannels -> Sat KnownNat outChannels -> SShape filterSpatialShape -> SShape s -> T (bs ': s ++ '[inChannels]) t -- input tensor (batched) -> T (filterSpatialShape ++ '[inChannels,outChannels]) t -- filters -> T (bs ': s ++ '[outChannels]) t Pool :: Length outSpatial ~ Length window => Sat KnownNat bs -> SShape window -> PoolingType -> Sat KnownNat numChannels -> SShape outSpatial -> T (bs ': ZipWithMulShapes window outSpatial ++ '[numChannels]) t -> T (bs ': outSpatial ++ '[numChannels]) t Softmax :: Sat KnownNat bs -> Sat KnownNat n -> T '[bs,n] (Flt w) -> T '[bs,n] (Flt w) -- yes, softmax is shape-fixed: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/softmax -- instance Show Unique where -- show _ = "" -- deriving instance (Show (T s t)) type family ZipWithMulShapes (xs::Shape) (xy::Shape) :: Shape type instance ZipWithMulShapes (x ': xs) (y ': ys) = x*y ': ZipWithMulShapes xs ys type instance ZipWithMulShapes '[] _ = '[] type instance ZipWithMulShapes _ '[] = '[] satMul :: forall n m. Sat KnownNat n -> Sat KnownNat m -> Sat KnownNat (n*m) satMul Sat Sat = Sat satProd :: SShape s -> Sat KnownNat (Product s) satProd Unit = natSat @1 satProd (x :* xs) = satMul x (satProd xs) satAdd :: forall n m. Sat KnownNat n -> Sat KnownNat m -> Sat KnownNat (n+m) satAdd Sat Sat = Sat zipWithMulSShapes :: SShape xs -> SShape ys -> SShape (ZipWithMulShapes xs ys) zipWithMulSShapes Unit _ = Unit zipWithMulSShapes _ Unit = Unit zipWithMulSShapes ((:*) x xs) ((:*) y ys) = (:*) (satMul x y) (zipWithMulSShapes xs ys) data PoolingType = MaxPool | AvgPool deriving Show type Tensor shape = T shape data ReduceOp = Mean | Max | Min | Sum data Axis1Op s1 t s2 u where ReverseT :: Sat KnownNat n -> Axis1Op '[n] t '[n] t ArgMax :: KnownNumeric t => Sat KnownNat n -> Axis1Op '[n] t '[] ('Typ 'Int b) OneHot :: KnownNumeric t => Sat KnownNat n -> Axis1Op '[] ('Typ 'Int b) '[n] t ReduceOp :: KnownNumeric t => Sat KnownNat n -> ReduceOp -> Axis1Op '[n] t '[] t SliceOp :: forall m n t proxy. proxy m -> Sat KnownNat n -> Integer -> Integer -> Axis1Op '[n] t '[m] t AccessOp :: forall n t. Sat KnownNat n -> Integer -> Axis1Op '[n] t '[] t data Float1Op = ClipByValue Float Float | Tanh | Sin | Exp | Sigmoid | HardSigmoid | Relu | Floor | Round | Cos | Log | Asin | Acos | Sinh | Cosh | Asinh | Acosh | Atan | Atanh | Sqrt deriving Show data Num1Op = Square | Negate | Abs | Sign deriving Show data Side = Upper | Lower data UnOp (s1 :: Shape) (t :: Typ) (s2 :: Shape) (u :: Typ) where ZeroTriangle :: KnownNumeric t => Sat KnownNat n -> Side -> Integer -> UnOp '[n,n] t '[n,n] t -- https://numpy.org/doc/1.16/reference/generated/numpy.tril.html ExpM :: KnownNumeric t => Sat KnownNat n -> UnOp '[n,n] t '[n,n] t Diag :: Sat KnownNat n -> UnOp '[n] t '[n,n] t StopGradient :: UnOp '[] t '[] t Cast :: UnOp '[] t '[] u Conjugate :: UnOp '[] ('Typ 'Cmplx w) '[] ('Typ 'Cmplx w) RealPart :: UnOp '[] ('Typ 'Cmplx w) '[] ('Typ 'Float w) Num1Op :: KnownNumeric t => Num1Op -> UnOp '[] t '[] t Float1Op :: Float1Op -> UnOp '[] (Flt w) '[] (Flt w) Axis1Op :: SShape s -> Axis1Op s1 t s2 u -> UnOp (s1 ++ s) t (s2 ++ s) u -- deriving Show data CompOp = Less | Greater | LessOrEqual | GreaterOrEqual data LogicOp = And | Or data Simple2Op t u where Divide :: KnownAlgebraic t => Simple2Op t t IntegerDiv :: Simple2Op ('Typ 'Int w) ('Typ 'Int w) Equal :: KnownTyp t => Simple2Op t TFBool Subtract :: KnownNumeric t => Simple2Op t t Multiply :: KnownNumeric t => Simple2Op t t Add :: KnownNumeric t => Simple2Op t t Minimum :: KnownNumeric t => Simple2Op t t Maximum :: KnownNumeric t => Simple2Op t t FloorMod :: KnownNumeric t => Simple2Op t t Comparision :: KnownNumeric t => CompOp -> Simple2Op t TFBool Logic :: LogicOp -> Simple2Op TFBool TFBool MkComplex :: Simple2Op (Flt w) ('Typ 'Cmplx w) -- deriving instance Show (Simple2Op t u) data BinOp s1 t1 s2 t2 s3 t3 where Simple2Op :: Simple2Op t u -> BinOp '[] t '[] t '[] u SigmoidCrossEntropyWithLogits :: KnownFloat t => BinOp '[] t '[] t '[] t SoftmaxCrossEntropyWithLogits :: KnownFloat t => BinOp '[n] t '[n] t '[] t SparseSoftmaxCrossEntropyWithLogits :: BinOp '[] Int32 '[n] (Flt w) '[] (Flt w) -- deriving instance Show (BinOp a b c d e f) data Permutation (s :: [k]) (t :: [k]) where PermId :: Permutation s s PermSkip :: Permutation s t -> Permutation (n ': s) (n ': t) PermSwap :: Permutation (n ': m ': s) (m ': n ': s) PermTrans :: Permutation s t -> Permutation t u -> Permutation s u deriving instance Show (Permutation s t) class KnownTensors p where -- TODO: delete -- | traverse all the tensors contained in p. travTensor :: Applicative m => (forall s t. (KnownTyp t, KnownShape s) => String -> (T s t) -> m (T s t)) -> String -> p -> m p instance (KnownTyp t, KnownShape shape) => KnownTensors (T shape t) where travTensor f = f instance (All KnownPair ys) => KnownTensors (HHTV ys) where travTensor :: forall m. Applicative m => (forall s t'. (KnownTyp t', KnownShape s) => String -> T s t' -> m (T s t')) -> String -> HHTV ys -> m (HHTV ys) travTensor f s = ttr 0 where ttr :: forall xs. All KnownPair xs => Int -> HHTV xs -> m (HHTV xs) ttr _ Unit = pure Unit ttr n (Uncurry x :* xs) = do x' <- f (s <> "_" <> show n) x xs' <- ttr (n + 1) xs return (Uncurry x' :* xs') instance (KnownTyp t, All KnownShape ys) => KnownTensors (HTV t ys) where travTensor :: forall m. Applicative m => (forall s t'. (KnownTyp t', KnownShape s) => String -> T s t' -> m (T s t')) -> String -> (HTV t ys) -> m (HTV t ys) travTensor f s = ttr 0 where ttr :: forall xs. All KnownShape xs => Int -> HTV t xs -> m (HTV t xs) ttr _ Unit = pure Unit ttr n (F x :* xs) = do x' <- f (s <> "_" <> show n) x xs' <- ttr (n + 1) xs return (F x' :* xs') instance (KnownTensors p, KnownTensors q) => KnownTensors (p,q) where travTensor f s (x,y) = (,) <$> travTensor f (s<>"_fst") x <*> travTensor f (s<>"_snd") y instance (KnownTensors p1, KnownTensors p2, KnownTensors p3) => KnownTensors (p1,p2,p3) where travTensor f s (x,y,z) = (,,) <$> travTensor f (s<>"_1") x <*> travTensor f (s<>"_2") y <*> travTensor f (s<>"_3") z instance (KnownTensors p1, KnownTensors p2, KnownTensors p3, KnownTensors p4) => KnownTensors (p1,p2,p3,p4) where travTensor f s (x,y,z,w) = (,,,) <$> travTensor f (s<>"_1") x <*> travTensor f (s<>"_2") y <*> travTensor f (s<>"_3") z <*> travTensor f (s<>"_4") w class KnownTensors p => ParamWithDefault p where defaultInitializer :: Gen p ================================================ FILE: TypedFlow.hs ================================================ {-| Module : TypedFlow Description : Higher-Order Typed Binding to TensorFlow and Deep Learning Library Copyright : (c) Jean-Philippe Bernardy, 2017 License : LGPL-3 Maintainer : jean-philippe.bernardy@gu.se Stability : experimental This module re-exports all functions. -} module TypedFlow (module TypedFlow.Types ,module TypedFlow.TF ,module TypedFlow.Layers ,module TypedFlow.Learn ,module GHC.TypeLits) where import TypedFlow.TF import TypedFlow.Types import TypedFlow.Layers import TypedFlow.Learn import GHC.TypeLits ================================================ FILE: cabal.project ================================================ packages: ./typedflow.cabal ================================================ FILE: docs/HOT.org ================================================ #+TITLE: TypedFlow: The HOT parts #+AUTHOR: Jean-Philippe Bernardy, University of Gothenburg TensorFlow™ is an open source software library for numerical computation using data flow graphs. Nodes in the graph represent mathematical operations, while the graph edges represent the multidimensional data arrays (tensors) communicated between them. TensorFlow graphs can be efficiently evaluated on GPUs and is a popular choice to implement deep learning applications. TypedFlow is a higher-order and typed (HOT) frontend to tensorflow, written in (Glasgow) Haskell. In this talk I will: - briefly explain what TensorFlow is and how it applies to deep learning - recall the advantages of a HOT approach - expose some example programs written using TypedFlow - demonstrate how tensorflow functions can be given precise types, using GHC extensions - discuss some the difficulties of doing so * Machine learning in 45 seconds: - a vector of training inputs X :: [A] - a model f : (Θ × A) → ℝ⁺ Task: Given X, find θ such that f(θ,x) < ε, if x is considered similar to points in X. Commentary: Every point in X lie on a manyfold. We want to find what this manyfold is. (Interpolation problem.) * "Deep" learning in 45 seconds - "Deep" ≡ f is "complex" - So we must use a brute force method to compute θ: stochastic gradient descent (or variants thereof). - Typically, compute the gradient of f wrt. θ using AD. * Tensorflow - A (meta) programming language to define f. AD is builtin. (there is fineprint) - Restricted control flow (mostly, tensor-generalisations of +, *, -, /, ^) - Typically programmed using python (standard in scientific computing) - "Strongly typed". - but: "brodcasting" - but: running the metaprogram can be quite slow ~1 minute (so type errors can happen after 1 minute of loading the model --- and programs can do other things before ...) - but: types are typically not written as such. Given any two functions, weather they compose (and what the composition does) is a mistery unless one examines their code. * First-order culture https://github.com/fchollet/keras/blob/master/keras/layers/recurrent.py (search "class LSTM") * TypedFlow - An typed, higher-order frontend to tensorflow (basic tensor operations) - A library to construct neural networks - Generates python (yikes) * Typing tensors - tanh - matmul - concatT - repeatT - tile - convolution * Heterogeneous tensors type HTV * Example Higher-order stuff - mapT - rnn - withBypass - Attention model * Complete examples - mnist - seq2seq * GHC woes - see transposeV * Summary - Some NN building blocks are naturally higher-order. Taking an example (and simplifying) a recurrent neural network turns a tensor function into a function between lists (vectorslists) of tensors. - Functional programming is ideally suited to program complicated applications from building blocks. Example: an "Attention-model" is a thing where every step in a RNN adds a computation which depends on an external input. We can compose usual RNN cells with attention models in several ways. The state of the art is to reprogram all combinations by hand. - Typed APIs. Types can be used to check the tensor dimensions. Types catch a lof of errors, but they can also be used to *guide* the programming. Types are pretty much a necessity in the presence of HO functions. - TypedFlow is typically much closer to mathematical notation than python. Programs are short to write and easier to read. Standard building blocks can be swapped for custom versions quite easily. Examples - rnn stacking using "residual connections" instead of just stacking. - make it easy to share parameters between different components (example: if we do a style translation we may want to share the embedding layers between encoder and decoders parts) - Long game: integrate cutting edge ideas as they arrive with moderate effort. * FAQ - Why not Agda, Idris? A long term plan is to bypass python, so we'd want a "real" programming language for the programming bits that go around the TF program. ================================================ FILE: docs/Talk.org ================================================ #+TITLE: TypedFlow: A library for higher-order typed deep learning #+AUTHOR: Jean-Philippe Bernardy, University of Gothenburg TensorFlow is a library for numerical computation, with specific features for machine-learning such as gradient computation. It is perhaps the most popular backend for deep learning applications. TypedFlow a higher-order and typed (HOT) frontend to TensorFlow written in Haskell, and a library of neural-network layers and combinators. In this talk I will: - briefly recall what TensorFlow is and how it applies to deep learning - discuss the advantages of a HOT approach vs. plain TensorFlow - expose two use-cases: the standard MNIST example and a sequence-to-sequence network with attention model. Ideas: transparency, explainability * Machine learning in 45 seconds: - a vector of training inputs X::[A] - a model f : (Θ × A) → ℝ⁺ Task: Given X, find θ such that f(θ,x) < ε, if x is considered similar to points in X, and > ε otherwise. Commentary: Every point in X lie on a manifold. We want to find what this manyfold is. (Interpolation problem.) * "Deep" learning in 45 seconds - "Deep" ≡ f is "complicated" - So we must use a brute force method to compute θ: stochastic gradient descent (or variants thereof). - Typically, compute the gradient of f wrt. θ using AD. * Tensorflow - A (meta) programming language to define f. AD is builtin. (there is fineprint) - Restricted control flow (mostly, tensor-generalisations of +, *, -, /, ^, tanh, ...) - Typically programmed using python (standard in scientific computing) - "Strongly typed" - but: no abstraction over dimensions - but: "brodcasting" - but: running the metaprogram can be quite slow ~1 minute (so type errors can happen after 1 minute of loading the model --- and programs can do other things before ...) - but: types are typically not written as such. Given any two functions, weather they compose (and what the composition does) is a mistery unless one examines their code. - "map" has a surprising semantics (see below) * What is TypedFlow? - An typed, higher-order frontend to tensorflow (basic tensor operations) - A library to construct neural networks - Generates python * Why TypedFlow? Functional programming is ideally suited to program complicated applications from building blocks. - Notation - Types - HO * Deep Learning: The state of the art [[file:cards.jpg]] (Actually this has become worse!) * Notation Haskell is typically much closer to mathematical notation than python. Programs are short to write and easier to read. file:../TypedFlow/TF.hs::/⊕.*::/ * Why Types? Types can be used to check the tensor dimensions. - Types catch a lof of errors - but they can also be used to *guide* the programming. "Type holes" (see MNIST example) Types are pretty much a necessity to take advantage of HO functions. #+BEGIN_QUOTE Together with the absence of side effects, rich type systems enable to construct complex programs with a high degree of confidence: - types precisely abstract the intention of the programmer for each function, without any hidden side effect, and - provided that they match the contracts imposed by types, functions can be freely combined, using lazy evaluation and higher-order facilities, without risk of pernicious interference. #+END_QUOTE * Python, aka The Culture of First Order [[file:imperiallegion.jpg]] https://github.com/fchollet/keras/blob/master/keras/layers/recurrent.py (search "class LSTM") * Example 1: LSTM file:../TypedFlow/Layers/RNN.hs::/^lstm.*::/ * Example 2: Attention Example: an "Attention-model" is a model where every step in a RNN adds a computation which depends on an external input. We can compose usual RNN cells with attention models in several ways. The state of the art is to reprogram such combinations by hand. file:../TypedFlow/Layers/RNN.hs::/^attentiveWithFeedback.*::/ * Mapping tensors - Tensorflow's ~map~ spawns processes. This is (usually) quite a bad idea --- tensor operations are parallelized anyway (but not on several GPUs... the purpose of ~map~ apparently). - Most (but not all!) operations have so-called "broadcast semantics"; they can be (implicitly!) raised to tensors of higher dimensions. - file:../TypedFlow/Abstract.hs::/^protoBroadcast.*::/ - Note "gather" goes to "gather_nd" - Certain convolutions can't be broadcasted at all 😿 * Pretending that tensor operations are functional - They are EXCEPT that sharing is lost - Use the old trick of observable sharing. (Memoizing, etc.) * Long game - Integrate cutting edge DL ideas as they arrive with moderate effort. * MNIST file:../examples/mnist/MNIST.hs * Seq2Seq file:../examples/seq2seq/Seq2Seq.hs ================================================ FILE: examples/agreement/Aggr.hs ================================================ {-# LANGUAGE ApplicativeDo #-} {-# LANGUAGE ViewPatterns #-} {-# LANGUAGE AllowAmbiguousTypes #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnicodeSyntax #-} import TypedFlow import TypedFlow.Python import qualified GHC.Int as GHC onFST :: (Tensor s1 t -> Tensor s t) -> HTV t '[s1, s'] -> HTV t '[s, s'] onFST f (VecPair h c) = (VecPair (f h) c) mkLSTM :: ∀ n x. KnownNat x => KnownNat n => String -> DropProb -> Gen (RnnCell Float32 '[ '[n], '[n]] (Tensor '[x] Float32) (Tensor '[n] Float32)) mkLSTM pName dropProb = do params <- parameterDefault pName drp1 <- mkDropout dropProb rdrp1 <- mkDropout dropProb return (timeDistribute drp1 .-. onStates (onFST rdrp1) (lstm params)) model :: forall (vocSize::Nat) (len::Nat). KnownNat len => KnownNat vocSize => Gen (T '[len] Int32 -> T '[len] Int32 -> ModelOutput Float32 '[len,vocSize] '[]) model = do embs <- parameterDefault "embs" let dropProb = DropProb 0.10 lstm1 <- mkLSTM @160 "w1" dropProb drp <- mkDropout dropProb w <- parameterDefault "dense" return $ \input gold -> do let masks = constant 1 ⊝ cast @Float32 (equal (constant padding) input) (_sFi,predictions) = simpleRnn (timeDistribute (embedding @12 @vocSize embs) .-. lstm1 .-. timeDistribute drp .-. timeDistribute (dense w)) (repeatT zeros, input) in timedCategorical masks predictions gold padding :: GHC.Int32 padding = 10 main :: IO () main = do generateFile "aggr.py" (compile @512 defaultOptions (model @12 @21)) putStrLn "done!" -- >>> main -- Parameters (total 134300): -- dense_bias: T [12] tf.float32 -- dense_w: T [160,12] tf.float32 -- w1_o_b: T [160] tf.float32 -- w1_o_w: T [172,160] tf.float32 -- w1_c_b: T [160] tf.float32 -- w1_c_w: T [172,160] tf.float32 -- w1_i_b: T [160] tf.float32 -- w1_i_w: T [172,160] tf.float32 -- w1_f_b: T [160] tf.float32 -- w1_f_w: T [172,160] tf.float32 -- embs: T [12,12] tf.float32 -- y: T [512,21] tf.int32 -- x: T [512,21] tf.int32 -- done! (|>) :: ∀ a b. a -> b -> (a, b) (|>) = (,) infixr |> -- Local Variables: -- dante-repl-command-line: ("nix-shell" ".styx/shell.nix" "--pure" "--run" "cabal repl") -- End: ================================================ FILE: examples/mnist/MNIST.hs ================================================ {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE ApplicativeDo #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnicodeSyntax #-} {-# LANGUAGE NoStarIsType #-} module MNIST where import TypedFlow import TypedFlow.Python atShape :: forall s t. T s t -> T s t atShape x = x mnist :: Gen (Model '[784] Float32 '[10] '[10] '[] Float32) mnist = do filters1 <- parameterDefault "f1" filters2 <- parameterDefault "f2" w1 <- parameterDefault "w1" w2 <- parameterDefault "w2" return $ \input gold -> let nn = dense @10 w2 . relu . dense @1024 w1 . reshape @'[7 * 7 * 64] . maxPool2D @2 @2 . relu . conv @64 @'[5,5] filters2 . maxPool2D @2 @2 . atShape @'[28,28,32] . relu . conv @32 @'[5,5] filters1 . reshape @'[28,28,1] logits = nn input in categoricalDistribution logits gold main :: IO () main = do generateFile "mnist_model.py" (compile @100 defaultOptions mnist) putStrLn "done!" -- >>> main -- Parameters (total 3274634): -- f1_filters: T [5, 5, 1, 32] tf.float32 -- f1_biases: T [32] tf.float32 -- f2_filters: T [5, 5, 32, 64] tf.float32 -- f2_biases: T [64] tf.float32 -- w1_w: T [3136, 1024] tf.float32 -- w1_bias: T [1024] tf.float32 -- w2_w: T [1024, 10] tf.float32 -- w2_bias: T [10] tf.float32 -- done! -- Local Variables: -- dante-repl-command-line: ("nix-shell" ".styx/shell.nix" "--pure" "--run" "cabal repl") -- End: ================================================ FILE: examples/mnist/Makefile ================================================ test: mnist_model.py main.py nix-shell ../seq2seq/shell.nix --run "python main.py" mnist_model.py: MNIST.hs nix-shell ../../.styx/shell.nix --run "ghci -i../.. MNIST.hs -e main" ================================================ FILE: examples/mnist/main.py ================================================ import sys sys.path.append('../..') # so we can see the rts. import typedflow_rts as tyf import tensorflow as tf import numpy as np from mnist_model import mkModel,runModel import os # comment out if you don't have CUDA tyf.cuda_use_one_free_device() optimizer = tf.keras.optimizers.Adam(1e-4) # import tfds.image_classification.MNIST as mnist # need to package tfds def train_generator(batch_size): for _ in range(1000): # (x,y) = mnist.batch(100) yield {"x":np.zeros((100,784), dtype=np.float32), # FIXME "y":np.zeros((100,10), dtype=np.float32) } model = mkModel() tyf.train(optimizer,model,runModel,train_generator) ================================================ FILE: examples/mnist/mnist_model.py ================================================ import tensorflow as tf def mkModel(): #shape: [25, 32] var10000=tf.random.uniform([25, 32], minval=-0.32444283, maxval=0.32444283, dtype=tf.float32) # 0 #shape: [5, 5, 1, 32] var10001=tf.reshape(var10000, [5, 5, 1, 32]) var10002=tf.Variable(name="f1_filters", trainable=True, initial_value=var10001) #shape: [] var10003=tf.constant(0.1, shape=[], dtype=tf.float32) #shape: [32] var10004=ERROR:BroadcastT(var10003) #shape: [32] var10005=tf.reshape(var10004, [32]) var10006=tf.Variable(name="f1_biases", trainable=True, initial_value=var10005) #shape: [800, 64] var10007=tf.random.uniform([800, 64], minval=-8.3333336e-2, maxval=8.3333336e-2, dtype=tf.float32) # 1 #shape: [5, 5, 32, 64] var10008=tf.reshape(var10007, [5, 5, 32, 64]) var10009=tf.Variable(name="f2_filters", trainable=True, initial_value=var10008) #shape: [64] var10010=tf.reshape(var10004, [64]) var10011=tf.Variable(name="f2_biases", trainable=True, initial_value=var10010) #shape: [3136, 1024] var10012=tf.random.uniform([3136, 1024], minval=-3.7977725e-2, maxval=3.7977725e-2, dtype=tf.float32) # 2 var10013=tf.Variable(name="w1_w", trainable=True, initial_value=var10012) #shape: [1024] var10014=tf.random.truncated_normal([1024], stddev=0.1, dtype=tf.float32) # 3 var10015=tf.Variable(name="w1_bias", trainable=True, initial_value=var10014) #shape: [1024, 10] var10016=tf.random.uniform([1024, 10], minval=-7.61755e-2, maxval=7.61755e-2, dtype=tf.float32) # 4 var10017=tf.Variable(name="w2_w", trainable=True, initial_value=var10016) #shape: [10] var10018=tf.random.truncated_normal([10], stddev=0.1, dtype=tf.float32) # 5 var10019=tf.Variable(name="w2_bias", trainable=True, initial_value=var10018) return {"batch_size":100, "parameters":[ var10002 , var10006 , var10009 , var10011 , var10013 , var10015 , var10017 , var10019 ], "paramsdict":{"f1_filters":var10002, "f1_biases":var10006, "f2_filters":var10009, "f2_biases":var10011, "w1_w":var10013, "w1_bias":var10015, "w2_w":var10017, "w2_bias":var10019}} @tf.function def runModel_fn(training_placeholder, f1_filters, f1_biases, f2_filters, f2_biases, w1_w, w1_bias, w2_w, w2_bias, x, y): #shape: [100, 10] var10020=y #shape: [100, 784] var10021=x #shape: [100, 28, 28, 1] var10022=tf.reshape(var10021, [100, 28, 28, 1]) #shape: [5, 5, 1, 32] var10023=f1_filters #shape: [100, 28, 28, 32] var10024=tf.nn.convolution(var10022, var10023, padding="SAME", data_format="NHWC") #shape: [100, 784, 32] var10025=tf.reshape(var10024, [100, 784, 32]) #shape: [32] var10026=f1_biases #shape: [784, 32] var10027=tf.broadcast_to(tf.reshape(var10026, [1, 32]), [784, 32]) #shape: [100, 784, 32] var10028=tf.broadcast_to(tf.reshape(var10027, [1, 784, 32]), [100, 784, 32]) #shape: [100, 784, 32] var10029=tf.add(var10025, var10028) #shape: [100, 28, 28, 32] var10030=tf.reshape(var10029, [100, 28, 28, 32]) #shape: [100, 28, 28, 32] var10031=tf.nn.relu(var10030) #shape: [100, 28, 28, 32] var10032=tf.reshape(var10031, [100, 28, 28, 32]) #shape: [100, 14, 14, 32] var10033=tf.nn.pool(var10032, [2, 2], "MAX", strides=[2, 2], padding="SAME") #shape: [100, 14, 14, 32] var10034=tf.reshape(var10033, [100, 14, 14, 32]) #shape: [5, 5, 32, 64] var10035=f2_filters #shape: [100, 14, 14, 64] var10036=tf.nn.convolution(var10034, var10035, padding="SAME", data_format="NHWC") #shape: [100, 196, 64] var10037=tf.reshape(var10036, [100, 196, 64]) #shape: [64] var10038=f2_biases #shape: [196, 64] var10039=tf.broadcast_to(tf.reshape(var10038, [1, 64]), [196, 64]) #shape: [100, 196, 64] var10040=tf.broadcast_to(tf.reshape(var10039, [1, 196, 64]), [100, 196, 64]) #shape: [100, 196, 64] var10041=tf.add(var10037, var10040) #shape: [100, 14, 14, 64] var10042=tf.reshape(var10041, [100, 14, 14, 64]) #shape: [100, 14, 14, 64] var10043=tf.nn.relu(var10042) #shape: [100, 14, 14, 64] var10044=tf.reshape(var10043, [100, 14, 14, 64]) #shape: [100, 7, 7, 64] var10045=tf.nn.pool(var10044, [2, 2], "MAX", strides=[2, 2], padding="SAME") #shape: [100, 3136] var10046=tf.reshape(var10045, [100, 3136]) #shape: [3136, 1024] var10047=w1_w #shape: [100, 1024] var10048=tf.matmul(var10046, var10047) #shape: [100, 1024] var10049=tf.reshape(var10048, [100, 1024]) #shape: [1024] var10050=w1_bias #shape: [100, 1024] var10051=tf.broadcast_to(tf.reshape(var10050, [1, 1024]), [100, 1024]) #shape: [100, 1024] var10052=tf.add(var10049, var10051) #shape: [100, 1024] var10053=tf.nn.relu(var10052) #shape: [100, 1024] var10054=tf.reshape(var10053, [100, 1024]) #shape: [1024, 10] var10055=w2_w #shape: [100, 10] var10056=tf.matmul(var10054, var10055) #shape: [100, 10] var10057=tf.reshape(var10056, [100, 10]) #shape: [10] var10058=w2_bias #shape: [100, 10] var10059=tf.broadcast_to(tf.reshape(var10058, [1, 10]), [100, 10]) #shape: [100, 10] var10060=tf.add(var10057, var10059) #shape: [100] var10061=tf.nn.softmax_cross_entropy_with_logits(labels=var10020, logits=var10060) #shape: [100] var10062=tf.reshape(var10061, [100]) #shape: [] var10063=tf.reduce_mean(var10062, axis=0) #shape: [] var10064=tf.constant(0.0, shape=[], dtype=tf.float32) #shape: [1] var10065=tf.broadcast_to(tf.reshape(var10064, [1]), [1]) #shape: [] var10066=tf.reshape(var10065, []) #shape: [] var10067=tf.add(var10063, var10066) #shape: [100] var10068=tf.argmax(var10060, axis=1, output_type=tf.int32) #shape: [100] var10069=tf.argmax(var10020, axis=1, output_type=tf.int32) #shape: [100] var10070=tf.equal(var10068, var10069) #shape: [100] var10071=tf.cast(var10070, tf.float32) #shape: [100] var10072=tf.reshape(var10071, [100]) #shape: [] var10073=tf.reduce_mean(var10072, axis=0) #shape: [100, 10] var10074=tf.reshape(var10060, [100, 10]) #shape: [100, 10] var10075=tf.nn.softmax(var10074, axis=1) #shape: [100, 10] var10076=tf.reshape(var10075, [100, 10]) return {"loss":var10067, "accuracy":var10073, "y_":var10076} runModel = {"function":runModel_fn, "batched":True, "placeholders":{"x":{"shape":[100, 784], "dtype":tf.float32}, "y":{"shape":[100, 10], "dtype":tf.float32}}} ================================================ FILE: examples/seq2seq/GenTr.hs ================================================ import Control.Applicative import Test.QuickCheck.Gen import Data.List import Data.Array data Abs a = Bin a (Abs a) (Abs a) | Leaf a deriving Show type Method a = a -> [a] -> [a] -> [a] parens :: String -> String parens xs = "(" ++ xs ++ ")" preorder :: Char -> [Char] -> [Char] -> String preorder x l r = (x : l ++ r) postorder :: Char -> [Char] -> [Char] -> String postorder x l r = (l ++ r ++ [x]) reversePO :: Char -> [Char] -> [Char] -> String reversePO x l r = (x : r ++ l) linearize _ (Leaf x) = [x] linearize m (Bin x l r) = parens (m x (lin l) (lin r)) where lin = linearize m mkMethods :: Eq a => [(a->Bool,Method a)] -> Method a mkMethods ms x = case find (\(p,_) -> p x) ms of Just (_,m) -> m x Nothing -> error "no applicable linearization method" linPO :: Abs Char -> [Char] linPO = linearize (mkMethods [(const True,preorder)]) lin1 :: Abs Char -> [Char] lin1 = linearize (mkMethods [(\x -> x < '3',reversePO),(const True,preorder)]) ex :: Abs Char ex = Bin 'a' (Bin '1' (Leaf 'b') (Leaf 'c')) (Leaf 'd') guard :: Alternative f => Bool -> f a -> f a guard True x = x guard False _ = empty arb :: Gen (Abs Char) arb = sized $ \n -> do oneof (take (max 1 n) [(Leaf <$> elements ['a'..'e']) ,resize (n-1) (Bin <$> elements ['0'..'4'] <*> arb <*> arb)]) arbOkSize :: Gen (Abs Char) arbOkSize = do x <- resize 6 arb let xx = linPO x if (length xx > 2 && length xx < 22) then return x else arbOkSize mySample :: Int -> IO [Abs Char] mySample n = generate (sequence $ replicate n arbOkSize) showEx :: Abs Char -> String showEx x = linPO x ++ "\t" ++ lin1 x test :: IO () test = mapM_ putStrLn . map showEx =<< mySample 10 main :: IO () main = writeFile "synthtrees.txt" . unlines . map showEx =<< mySample 100000 ================================================ FILE: examples/seq2seq/Makefile ================================================ test: s2s.py synthtrees.txt main.py nix-shell --run "python main.py" s2s.py: Seq2Seq.hs nix-shell ../../.styx/shell.nix --run "ghci -i../.. Seq2Seq.hs -e main" synthtrees.txt: GenTr.hs nix-shell ../../.styx/shell.nix --run "ghc --make GenTr" ./GenTr ================================================ FILE: examples/seq2seq/Seq2Seq.hs ================================================ {-# LANGUAGE AllowAmbiguousTypes #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UnicodeSyntax #-} module Main where import TypedFlow import TypedFlow.Python mkLSTM :: ∀ n x w. KnownNat x => KnownNat n => KnownBits w => String -> Gen (RnnCell w '[ '[n], '[n]] (Tensor '[x] (Flt w)) (Tensor '[n] (Flt w))) mkLSTM pName = do params <- parameterDefault pName drp1 <- mkDropout (DropProb 0.05) rdrp1 <- mkDropouts (DropProb 0.05) return (timeDistribute drp1 .-. onStates rdrp1 (lstm params)) encoder :: forall (lstmSize :: Nat) (vocSize :: Nat) (n :: Nat) w. KnownNat lstmSize => KnownNat vocSize => (KnownNat n) => KnownBits w => String -> Gen ( T '[] Int32 -- length -> Tensor '[n] Int32 -> ((HTV (Flt w) '[ '[lstmSize], '[lstmSize] ], Tensor '[n, lstmSize] (Flt w)))) encoder prefix = do embs <- parameterDefault (prefix++"embs") lstm1 <- mkLSTM (prefix++"lstm1") return $ \len input -> runRnn (iterateWithCull len (timeDistribute (embedding @vocSize @vocSize embs) .-. lstm1)) (repeatT zeros, input) decoder :: forall (lstmSize :: Nat) (n :: Nat) (outVocabSize :: Nat) (d::Nat) w. KnownNat lstmSize => KnownNat d => (KnownNat outVocabSize, KnownNat n) => KnownBits w => String -> Gen ( T '[] Int32 -- ^ length -> T '[n, d] (Flt w) -- todo: consider a larger size for the output string -> HTV (Flt w) '[ '[lstmSize], '[lstmSize] ] -> Tensor '[n] Int32 -> Tensor '[n, outVocabSize] (Flt w)) decoder prefix = do -- note: for an intra-language translation the embeddings can be shared easily. projs <- parameterDefault (prefix++"proj") lstm1 <- mkLSTM (prefix++"lstm1") embs <- parameterDefault "embs" w1 <- parameter' (prefix++"att1") =<< glorotUniform return $ \ lens hs thoughtVectors targetInput -> let attn = uniformAttn (multiplicativeScoring w1) lens hs -- NOTE: attention on the left-part of the input. (_sFinal,outFinal) = simpleRnn ((timeDistribute (embedding @outVocabSize @outVocabSize embs) .-. attentiveWithFeedback attn lstm1 .-. timeDistribute (dense projs))) ((F zeros :* thoughtVectors), targetInput) in outFinal seq2seq :: forall (vocSize :: Nat) (n :: Nat). KnownNat vocSize => (KnownNat n) => Gen (Placeholders '[ '("tgt_weights", '[n], Float32), '("src_in", '[n], Int32), '("src_len", '[], Int32), '("tgt_in", '[n], Int32), '("tgt_out", '[n], Int32)] -> ModelOutput Float32 '[n, vocSize] '[]) seq2seq = do enc <- encoder @256 @vocSize "enc" dec <- decoder "dec" return $ \(PHT masks :* PHT input :* PHT inputLen :* PHT tgtIn :* PHT tgtOut :* Unit) -> let (VecPair t1 t2,h) = enc inputLen input y_ = dec inputLen h (VecPair t1 t2) tgtIn in timedCategorical masks y_ tgtOut main :: IO () main = generateFile "s2s.py" (compileGen @256 defaultOptions {maxGradientNorm = Just 5} (stateless <$> seq2seq @15 @22)) -- >>> main -- Parameters (total 889041): -- decatt1: T [256,256] tf.float32 -- embs: T [15,15] tf.float32 -- declstm1_o_bias: T [256] tf.float32 -- declstm1_o_w: T [527,256] tf.float32 -- declstm1_c_bias: T [256] tf.float32 -- declstm1_c_w: T [527,256] tf.float32 -- declstm1_i_bias: T [256] tf.float32 -- declstm1_i_w: T [527,256] tf.float32 -- declstm1_f_bias: T [256] tf.float32 -- declstm1_f_w: T [527,256] tf.float32 -- decproj_bias: T [15] tf.float32 -- decproj_w: T [256,15] tf.float32 -- enclstm1_o_bias: T [256] tf.float32 -- enclstm1_o_w: T [271,256] tf.float32 -- enclstm1_c_bias: T [256] tf.float32 -- enclstm1_c_w: T [271,256] tf.float32 -- enclstm1_i_bias: T [256] tf.float32 -- enclstm1_i_w: T [271,256] tf.float32 -- enclstm1_f_bias: T [256] tf.float32 -- enclstm1_f_w: T [271,256] tf.float32 -- encembs: T [15,15] tf.float32 -- Local Variables: -- dante-repl-command-line: ("nix-shell" ".styx/shell.nix" "--pure" "--run" "cabal repl") -- End: ================================================ FILE: examples/seq2seq/main.py ================================================ import sys sys.path.append('../..') # so we can see the rts. import typedflow_rts as tyf import tensorflow as tf import numpy as np from s2s import mkModel import os import math import random # comment out if you don't have CUDA tyf.cuda_use_one_free_device() chars = sorted(list("()01234abcde^$ ")) print('total chars:', len(chars)) char_indices = dict((c, i) for i, c in enumerate(chars)) indices_char = dict((i, c) for i, c in enumerate(chars)) MAXLEN = 22 def pad(ws): return (ws + ' '*(MAXLEN - len(ws))) def encode(s): # print ("proun", s) return np.array([char_indices[c] for c in s]) def decode(s): return "".join([indices_char[c] for c in list(s)]) def pad_right(sentence): return (MAXLEN - len(sentence)) * " " + sentence def pad_left(sentence): return sentence + (MAXLEN - len(sentence)) * " " def source_input_conversion(s): return encode(pad_left(s)) def target_input_conversion(sentence): return encode(pad_left("^"+sentence)) def target_output_conversion(sentence): return encode(pad_left(sentence+"$")) def sentence_target_weights(sentence): l = len(sentence) w = (l + 1) * [1] + (MAXLEN - (l + 1)) * [0] return np.array(w) def map(f,l): return [f(x) for x in l] def make_examples(l): (l1,l2) = zip(*l) return {"src_in":map(source_input_conversion,l1), "src_len":map(len,l1), "tgt_in":map(target_input_conversion,l2), "tgt_out":map(target_output_conversion,l2), "tgt_weights":map(sentence_target_weights,l2)} def s2s_generator(src_len,src_in,tgt_in,tgt_out,tgt_weights): def gen(bs): for i in range(0, bs*(len(src_in)//bs), bs): # print ({"src_len":src_len[i:i+bs], "src_in":src_in[i:i+bs], "tgt_in":tgt_in[i:i+bs], "tgt_out":tgt_out[i:i+bs], "tgt_weights":tgt_weights[i:i+bs]}) yield {"src_len":src_len[i:i+bs], "src_in":src_in[i:i+bs], "tgt_in":tgt_in[i:i+bs], "tgt_out":tgt_out[i:i+bs], "tgt_weights":tgt_weights[i:i+bs]} return gen def my_sample(l,n): return list(random.sample(l,min(n,len(l)))) print("Reading sentences...") all_sentences = [l.strip().split("\t") for l in open("synthtrees.txt").readlines()] val_set = make_examples(all_sentences[:2000]) train_set = make_examples(all_sentences[2000:]) print("Loading model") model = mkModel(tf.train.AdamOptimizer()) sess = tf.Session() saver = tf.train.Saver() def printer(x): (p,y,h) = x print("Prob", p, decode(y),h) def translate(s): r = tyf.beam_translate(sess,model,14, source_input_conversion(s), len(s), char_indices["^"], char_indices["$"], printer) for x in r: printer(x) def translate_cb(values): if values["epoch"] % 10 == 0: save_path = saver.save(sess, "model.ckpt") translate("(1(3cb)b)") print ("Desired:", "(1b(3cb))") return False tyf.initialize_params(sess,model) train_stats = tyf.train(sess, model, s2s_generator(**train_set), valid_generator = s2s_generator(**val_set), epochs=5000, callbacks=[tyf.StopWhenAccurate(.01), translate_cb]) translate("(1(3cb)b)") translate("(1(2c(3e(4(1cb)b)))c)") ================================================ FILE: examples/seq2seq/shell.nix ================================================ { bootstrap ? import {} }: let nixpkgs_source = fetchTarball https://github.com/NixOS/nixpkgs/archive/nixos-20.03.tar.gz; # nixpkgs_source = fetchTarball https://github.com/NixOS/nixpkgs/archive/4cf0b6ba5d5ab5eb20a88449e0612f4dad8e4c29.tar.gz; # nixpkgs_source = bootstrap.fetchFromGitHub { # for safety of checking the hash # owner = "jyp"; # repo = "nixpkgs"; # rev = "6b911c2d99ad116fca338fc26de86b8859079322"; # sha256 = "1bhwjkynya653mvpc4wwqks6kxnc06gyw6sbpwp8dbyr444ms4bd"; # }; # nixpkgs_source = ~/repo/nixpkgs; in with (import nixpkgs_source {}).pkgs; let py = (pkgs.python37.withPackages (ps: [ps.tensorflow-bin_2 ps.nltk])); in pkgs.stdenv.mkDerivation { name = "my-env-0"; buildInputs = [ py ]; } ================================================ FILE: styx.yaml ================================================ local-packages: typedflow: location: . nix-deps: - QuickCheck - hscolour # non-haskell-deps: # - glibcLocales nixpkgs: # commit: 80812af9e46167e3104038f2af6de251f90823a8 # sha256: 0b718zkn5lhy71pyp0klbz7w872zck0ljqfk17f0b56k3rlvp1sy url: https://github.com/NixOS/nixpkgs/archive/nixos-21.05.tar.gz ================================================ FILE: typedflow.cabal ================================================ name: typedflow version: 0.9 category: Deep Learning synopsis: Typed frontend to TensorFlow and higher-order deep learning description: TypedFlow is a typed, higher-order frontend to TensorFlow and a high-level library for deep-learning. . The main design principles are: . - To make the parameters of layers explicit. This choice makes sharing of parameters explicit and allows to implement "layers" as pure functions. . - To provide as precise as possible types. Functions are explicit about the shapes and elements of the tensors that they manipulate (they are often polymorphic in shapes and elements though.) . - To let combinators be as transparent as possible. If a NN layers is a simple tensor transformation it will be exposed as such. license: LGPL-3 license-file: LICENSE author: Jean-Philippe Bernardy maintainer: jean-philippe.bernardy@gu.se Cabal-Version: >= 1.12 build-type: Simple source-repository head type: git location: git@github.com:GU-CLASP/TypedFlow.git library default-language: Haskell2010 build-depends: base==4.*, ghc-typelits-knownnat, prettyprinter, mtl, containers -- ,tensorflow-opgen, tensorflow, tensorflow-core-ops, tensorflow-ops exposed-modules: TypedFlow, TypedFlow.Layers, TypedFlow.Layers.Core, TypedFlow.Layers.RNN, TypedFlow.Layers.RNN.Base, TypedFlow.Layers.RNN.Cells, TypedFlow.Layers.RNN.Attention, TypedFlow.Learn, TypedFlow.Models.Topic, TypedFlow.Models.Transformer, TypedFlow.Python, TypedFlow.TF, TypedFlow.Types, TypedFlow.Types.Proofs other-modules: TypedFlow.Memo TypedFlow.Memo2 TypedFlow.Abstract TypedFlow.Broadcast ================================================ FILE: typedflow_rts.py ================================================ import tensorflow as tf import numpy as np import sys from time import time import os import random ############################################################### # Devices ############################################################### def cuda_use_device(n): """Attempt to use a given CUDA device by setting the appropriate environment variables""" os.environ["CUDA_DEVICE_ORDER"]= "PCI_BUS_ID" if os.environ.get("CUDA_VISIBLE_DEVICES") is None: os.environ["CUDA_VISIBLE_DEVICES"] = str(n) def find_free_cuda_device(): currentGPU = -1 gpuMemory=dict() gpuUtil=dict() for line in os.popen("nvidia-smi -q"): fields = list(map(lambda x: x.strip(), line.split(":"))) k = fields[0] if k == "Minor Number": currentGPU += 1 gpuMemory[currentGPU] = 0 elif k == "Used GPU Memory": gpuMemory[currentGPU] = int(fields[1][:-4]) # last characters are " MiB" elif k == "Gpu": gpuUtil[currentGPU] = fields[1] # last characters are " %" minUse = min(gpuMemory.values()) freeGpus = [g for g in gpuMemory.keys() if gpuMemory[g] == minUse] if freeGpus == []: print("No free GPU could be found.") assert False else: result = random.choice(freeGpus) print ("Found device",result,"currently used at",gpuUtil[result],"and with",gpuMemory[result],"MB taken.") return result def cuda_use_one_free_device(): """Attempt to use a free CUDA device by setting the appropriate environment variables""" cuda_use_device(find_free_cuda_device()) ############################################################### # Generators ############################################################### def bilist_generator(l): """ Given a pair of x and y (each being a list or a np array) and a batch size, return a generator function which will yield the input in bs-sized chunks. Attention: if the size of the input is not divisible by bs, then the remainer will not be fed. Consider shuffling the input. """ (l0,l1) = l def gen(bs): if len(l0) == 0: return for i in range(0, bs*(len(l0)//bs), bs): yield {"x":l0[i:i+bs],"y":l1[i:i+bs]} return gen def bilist_generator_transposed(model,l): ''' Given a pair of l=(x,y) (both x,y being a list or a np array) and a batch size, return a generator function which will yield the input in bs*maxlen-sized chunks. This generator is intended to be used for stateful language models. That is, batch sequencing corresponds to ''' (batch_size,maxlen) = model["x"].shape (xs,ys) = l num_items = len(xs) // (batch_size*maxlen) x = np.zeros(shape=(num_items,batch_size,maxlen)) y = np.zeros(shape=(num_items,batch_size,maxlen)) for i in range(num_items): for j in range(batch_size): for k in range(maxlen): x[i][j][k] = xs[k+j*(num_items*maxlen)+i*maxlen] y[i][j][k] = ys[k+j*(num_items*maxlen)+i*maxlen] def gen(_bs): nonlocal num_items, x, y for i in range(num_items): yield {"x":x[i],"y":y[i]} return gen def dict_generator (xs): k0 = next (iter (xs.keys())) # at least one key is needed total_len = len(xs[k0]) def gen(bs): for i in range(0, bs*(total_len//bs), bs): yield dict((k,xs[k][i:i+bs]) for k in xs) return gen def initialize_params (session,model): '''Initialize the learnable parameters of the model''' # it'd be nice to do: # session.run(tf.variables_initializer(model["params"])) # However this does not initialize the optimizer's variables. So, # instead we do: session.run(tf.local_variables_initializer()) session.run(tf.global_variables_initializer()) def train (optimizer, model_static, model_fn, train_generator=bilist_generator(([],[])), valid_generator=bilist_generator(([],[])), epochs=100, callbacks=[], extraVectors=[]): ''' Train the given model. train_generator: training data valid_generator: validation data epochs: number of epochs callbacks: list of callbacks. Each callback receives an epoch entry (see below). If it returns False then the training is aborted. extraVectors: list of extra vectors to pass to session.run when training. modelPrefix: in case of a multitask/multimodel, give the prefix of the model to use. This function returns a list of epoch entries. Each entry is a dictionary with: - "epoch": current epoch - "val" and "train": dictionaries with - "loss", "accuracy", "error_rate", time", "start_time", "end_time" ''' batch_size = model_static["batch_size"] train_vars = model_static["parameters"] placeholders_info = model_fn["placeholders"] stats = [] def halfEpoch(isTraining): totalAccur = 0 totalLoss = 0 n = 0 print ("Training" if isTraining else "Validation", end="") start_time = time() for inputs in train_generator(batch_size) if isTraining else valid_generator(batch_size): cast_inputs = dict((k,tf.cast(inputs[k], placeholders_info[k]["dtype"])) for k in placeholders_info) # the above forces inputs to be tensors. (It's convenient to pass just lists here) print(".",end="") sys.stdout.flush() with tf.GradientTape() as tape: results = model_fn["function"](tf.constant(isTraining, shape=[]), **{**(model_static["paramsdict"]), **cast_inputs}) loss = results["loss"] accur = results["accuracy"] if isTraining: grads = tape.gradient(loss, train_vars) optimizer.apply_gradients(zip(grads, train_vars)) n+=1 totalLoss += loss totalAccur += accur end_time = time() totalAccur = totalAccur.numpy() totalLoss = totalLoss.numpy() if n > 0: avgLoss = totalLoss / float(n) avgAccur = totalAccur / float(n) print(".") print ("Time=%.1f" % (end_time - start_time), "loss=%g" % avgLoss, "accuracy=%.3f" % avgAccur) return {"loss":avgLoss,"accuracy":avgAccur,"time":(end_time - start_time),"error_rate":1-avgAccur,"start_time":start_time} else: print ("No data") return {"loss":0,"accur":0,"time":0,"error_rate":0,"start_time":0} for e in range(epochs): print ("Epoch {0}/{1}".format(e, epochs)) tr = halfEpoch(True) va = halfEpoch(False) epoch_stats = {"train":tr, "val":va, "epoch":e} stats.append(epoch_stats) if any(c(epoch_stats) for c in callbacks): break return stats def StopWhenValidationGetsWorse(patience = 1): '''Return a callback which stops training if validation loss gets worse.''' bestLoss = 10000000000 p = patience def callback(values): nonlocal bestLoss, p, patience newLoss = values["val"]["loss"] if newLoss > bestLoss: p -= 1 else: bestLoss = newLoss p = patience if p <= 0: return True return False return callback def StopWhenAccurate(phase="val",error_rate = .01): '''Return a callback which stops training if error rate drops below 1%''' def callback(values): nonlocal error_rate return values[phase]["error_rate"] < error_rate return callback def Every(n,f): '''Return a callback which calls its argument every n epochs''' def callback(values): nonlocal n,f if values["epoch"] % n == (n-1): return f(values) else: return False return callback def Save(sess,saver,ckptfile): def callback(values): nonlocal sess,saver print("Saving to",ckptfile) saver.save(sess, ckptfile) return False return callback ################################################################################################ # Prediction and evaluation def evaluate (model_static, model_fn, xs, result="y_"): '''Evaluate the model for given input and result. Input is given as a dictionary of lists to pass to session.run''' phs = model_fn["placeholders"] if phs: k0 = next (iter (phs.keys())) # 1st placeholder total_len = len(xs[k0]) # total length else: total_len = 1 zeros = dict((k,tf.zeros(phs[k]["shape"][1:], # remove the batch size dtype=phs[k]["dtype"])) for k in phs.keys()) results = [] if model_fn["batched"]: def run(): bs = model_static["batch_size"] for i in range(0, bs*(-(-total_len//bs)), bs): print(".",end="") chunks = dict() for k in phs: chunks[k] = xs[k][i:i+bs] if i + bs > total_len: # dealing with an incomplete last chunk origLen = total_len - i for k in chunks: chunks[k] = list(chunks[k]) + [zeros[k]] * (bs - origLen) # pad the last chunk else: origLen = bs chunks = {k: tf.cast(v,dtype=phs[k]["dtype"]) for (k,v) in chunks.items()} results = model_fn["function"](tf.constant(False, shape=[]), **{**(model_static["paramsdict"]), **chunks}) yield results[result][:origLen] return np.concatenate(list(run())) else: def run(): for i in range(total_len): inputs = {k: tf.cast(xs[k][i], dtype=phs[k]["dtype"]) for k in phs} results = model_fn["function"](tf.constant(False, shape=[]), **{**(model_static["paramsdict"]), **inputs}) yield results[result] return list(run()) predict = evaluate def beam_translate(session, model, k, x, xlen, start_symbol, stop_symbol, debug=None): '''Beam translation of ONE input sentence.''' (_,out_len,voc_size) = model["y_"].shape xs = np.array ([x] * k) # The input is always the same xs_len = np.array ([xlen]*k) # it is VERY important to get the length right ys = [[start_symbol]] # start with a single thing; otherwise the minimum will be repeated k times probs = [1] results = [] hist = [[]] def pad(z): return np.array(z + [0] * (out_len - len(z))) for i in range(out_len-1): print ("beam search at:", i) inputs = {"src_len":xs_len[:len(ys)], "src_in":xs[:len(ys)], "tgt_in":np.array([pad(y) for y in ys])} y_s = predict(session,model,inputs) all_words = sorted([(y_s[j][i][w] * probs[j], ys[j] + [w], hist[j] + [y_s[j][i][w]]) for j in range(len(y_s)) for w in range(voc_size)]) best = all_words[-k:] if debug is not None: for x in best: debug(x) results += [(p,y,h) for (p,y,h) in best if y[i+1] == stop_symbol] continued = [(p,y,h) for (p,y,h) in best if y[i+1] != stop_symbol] if len(continued) == 0: break (probs,ys,hist) = zip(*continued) return sorted(results) ###################################################### # Saving and loading def save(model_static, file): numpy_tensors = {k:v.numpy() for (k,v) in model_static["paramsdict"].items()} print("Saving parameters: ", model_static["paramsdict"].keys()) np.savez(file,**numpy_tensors) print("Done") def load(model_static, file): print("Loading parameters") numpy_tensors = np.load(file) print("Loaded parameters: ", list(numpy_tensors.keys())) for k,v in model_static["paramsdict"].items(): v.assign(numpy_tensors[k]) print("Done")