[
  {
    "path": ".gitignore",
    "content": ".styx\n*~\ndist\ndist-*\ncabal-dev\n*.o\n*.hi\n*.chi\n*.chs.h\n*.dyn_o\n*.dyn_hi\n.hpc\n.hsenv\n.cabal-sandbox/\ncabal.sandbox.config\n*.prof\n*.aux\n*.hp\n*.eventlog\n.stack-work/\ncabal.project.local\n.HTF/\n/examples/seq2seq/s2s.py\n/examples/seq2seq/synthtrees.txt\nMNIST_data\n__pycache__\n/examples/seq2seq/GenTr\n/.tramp_history\n"
  },
  {
    "path": "LICENSE",
    "content": "                   GNU LESSER GENERAL PUBLIC LICENSE\n                       Version 3, 29 June 2007\n\n Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>\n Everyone is permitted to copy and distribute verbatim copies\n of this license document, but changing it is not allowed.\n\n\n  This version of the GNU Lesser General Public License incorporates\nthe terms and conditions of version 3 of the GNU General Public\nLicense, supplemented by the additional permissions listed below.\n\n  0. Additional Definitions.\n\n  As used herein, \"this License\" refers to version 3 of the GNU Lesser\nGeneral Public License, and the \"GNU GPL\" refers to version 3 of the GNU\nGeneral Public License.\n\n  \"The Library\" refers to a covered work governed by this License,\nother than an Application or a Combined Work as defined below.\n\n  An \"Application\" is any work that makes use of an interface provided\nby the Library, but which is not otherwise based on the Library.\nDefining a subclass of a class defined by the Library is deemed a mode\nof using an interface provided by the Library.\n\n  A \"Combined Work\" is a work produced by combining or linking an\nApplication with the Library.  The particular version of the Library\nwith which the Combined Work was made is also called the \"Linked\nVersion\".\n\n  The \"Minimal Corresponding Source\" for a Combined Work means the\nCorresponding Source for the Combined Work, excluding any source code\nfor portions of the Combined Work that, considered in isolation, are\nbased on the Application, and not on the Linked Version.\n\n  The \"Corresponding Application Code\" for a Combined Work means the\nobject code and/or source code for the Application, including any data\nand utility programs needed for reproducing the Combined Work from the\nApplication, but excluding the System Libraries of the Combined Work.\n\n  1. 
Exception to Section 3 of the GNU GPL.\n\n  You may convey a covered work under sections 3 and 4 of this License\nwithout being bound by section 3 of the GNU GPL.\n\n  2. Conveying Modified Versions.\n\n  If you modify a copy of the Library, and, in your modifications, a\nfacility refers to a function or data to be supplied by an Application\nthat uses the facility (other than as an argument passed when the\nfacility is invoked), then you may convey a copy of the modified\nversion:\n\n   a) under this License, provided that you make a good faith effort to\n   ensure that, in the event an Application does not supply the\n   function or data, the facility still operates, and performs\n   whatever part of its purpose remains meaningful, or\n\n   b) under the GNU GPL, with none of the additional permissions of\n   this License applicable to that copy.\n\n  3. Object Code Incorporating Material from Library Header Files.\n\n  The object code form of an Application may incorporate material from\na header file that is part of the Library.  You may convey such object\ncode under terms of your choice, provided that, if the incorporated\nmaterial is not limited to numerical parameters, data structure\nlayouts and accessors, or small macros, inline functions and templates\n(ten or fewer lines in length), you do both of the following:\n\n   a) Give prominent notice with each copy of the object code that the\n   Library is used in it and that the Library and its use are\n   covered by this License.\n\n   b) Accompany the object code with a copy of the GNU GPL and this license\n   document.\n\n  4. 
Combined Works.\n\n  You may convey a Combined Work under terms of your choice that,\ntaken together, effectively do not restrict modification of the\nportions of the Library contained in the Combined Work and reverse\nengineering for debugging such modifications, if you also do each of\nthe following:\n\n   a) Give prominent notice with each copy of the Combined Work that\n   the Library is used in it and that the Library and its use are\n   covered by this License.\n\n   b) Accompany the Combined Work with a copy of the GNU GPL and this license\n   document.\n\n   c) For a Combined Work that displays copyright notices during\n   execution, include the copyright notice for the Library among\n   these notices, as well as a reference directing the user to the\n   copies of the GNU GPL and this license document.\n\n   d) Do one of the following:\n\n       0) Convey the Minimal Corresponding Source under the terms of this\n       License, and the Corresponding Application Code in a form\n       suitable for, and under terms that permit, the user to\n       recombine or relink the Application with a modified version of\n       the Linked Version to produce a modified Combined Work, in the\n       manner specified by section 6 of the GNU GPL for conveying\n       Corresponding Source.\n\n       1) Use a suitable shared library mechanism for linking with the\n       Library.  
A suitable mechanism is one that (a) uses at run time\n       a copy of the Library already present on the user's computer\n       system, and (b) will operate properly with a modified version\n       of the Library that is interface-compatible with the Linked\n       Version.\n\n   e) Provide Installation Information, but only if you would otherwise\n   be required to provide such information under section 6 of the\n   GNU GPL, and only to the extent that such information is\n   necessary to install and execute a modified version of the\n   Combined Work produced by recombining or relinking the\n   Application with a modified version of the Linked Version. (If\n   you use option 4d0, the Installation Information must accompany\n   the Minimal Corresponding Source and Corresponding Application\n   Code. If you use option 4d1, you must provide the Installation\n   Information in the manner specified by section 6 of the GNU GPL\n   for conveying Corresponding Source.)\n\n  5. Combined Libraries.\n\n  You may place library facilities that are a work based on the\nLibrary side by side in a single library together with other library\nfacilities that are not Applications and are not covered by this\nLicense, and convey such a combined library under terms of your\nchoice, if you do both of the following:\n\n   a) Accompany the combined library with a copy of the same work based\n   on the Library, uncombined with any other library facilities,\n   conveyed under the terms of this License.\n\n   b) Give prominent notice with the combined library that part of it\n   is a work based on the Library, and explaining where to find the\n   accompanying uncombined form of the same work.\n\n  6. Revised Versions of the GNU Lesser General Public License.\n\n  The Free Software Foundation may publish revised and/or new versions\nof the GNU Lesser General Public License from time to time. 
Such new\nversions will be similar in spirit to the present version, but may\ndiffer in detail to address new problems or concerns.\n\n  Each version is given a distinguishing version number. If the\nLibrary as you received it specifies that a certain numbered version\nof the GNU Lesser General Public License \"or any later version\"\napplies to it, you have the option of following the terms and\nconditions either of that published version or of any later version\npublished by the Free Software Foundation. If the Library as you\nreceived it does not specify a version number of the GNU Lesser\nGeneral Public License, you may choose any version of the GNU Lesser\nGeneral Public License ever published by the Free Software Foundation.\n\n  If the Library as you received it specifies that a proxy can decide\nwhether future versions of the GNU Lesser General Public License shall\napply, that proxy's public statement of acceptance of any version is\npermanent authorization for you to choose that version for the\nLibrary.\n"
  },
  {
    "path": "Makefile",
    "content": "\nviewdoc: dist/doc/html/typedflow/index.html\n\txdg-open $<\n\ndist/doc/html/typedflow/index.html:\n\tstyx cabal -- haddock --hyperlink-source\n\tstyx cabal -- hscolour\n\n"
  },
  {
    "path": "README.org",
    "content": "#+TITLE: TypedFlow\n\nTypedFlow is a typed, higher-order frontend to [[http://www.tensorflow.org][TensorFlow]] and a\nhigh-level library for deep learning.\n\nThe main design principles are:\n\n  - To make the parameters of layers explicit. This choice makes\n    sharing of parameters explicit and allows implementing \"layers\" as\n    pure functions.\n\n  - To provide types as precise as possible. Functions are explicit\n    about the shapes and elements of the tensors that they manipulate\n    (they are often polymorphic in shapes and elements though.)\n\n  - To let combinators be as transparent as possible. If an NN layer\n    is a simple tensor transformation it will be exposed as such.\n\n\nIn this version, the interface to TensorFlow is done via python-code\ngeneration and a suitable runtime system.\n\n** Documentation\n\nThe compiled documentation should be found on [[https://hackage.haskell.org/package/typedflow][hackage]].\n\n** Examples\n\nTypedFlow comes with two examples of neural networks:\n\n - An adaptation of the [[examples/mnist][MNIST tensorflow tutorial]]\n - A simple [[examples/seq2seq][sequence to sequence model]] which\n   attempts to learn to translate pre-order into post-order.\n\nThe examples can be run like so:\n\n#+BEGIN_SRC shell\nnix-env -iA nixpkgs.haskellPackages.styx\nnix-env -iA nixpkgs.cabal2nix\nstyx configure\ncd examples/seq2seq\nmake\n#+END_SRC\n\n"
  },
  {
    "path": "TypedFlow/Abstract.hs",
    "content": "{-# LANGUAGE InstanceSigs #-}\n{-|\nModule      : TypedFlow.Abstract\nDescription : Abstract Tensor representations\nCopyright   : (c) Jean-Philippe Bernardy, 2018\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n\nThis module provides operations on the abstract representation of\ntensor operations. It is not normally imported directly by users.\n-}\n\n{-# LANGUAGE ApplicativeDo #-}\n{-# LANGUAGE LambdaCase #-}\n{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE MultiParamTypeClasses #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE RecordWildCards #-}\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeInType #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-# LANGUAGE CPP #-}\n#if __GLASGOW_HASKELL__ >= 806\n{-# LANGUAGE NoStarIsType #-}\n#endif\n\nmodule TypedFlow.Abstract where\n\nimport Control.Monad.RWS (RWS, tell, runRWS)\nimport Control.Monad.State\nimport Data.Proxy\nimport Data.Type.Equality\nimport GHC.TypeLits\nimport Prelude hiding (RealFrac(..))\nimport qualified TypedFlow.Memo as Memo0\nimport TypedFlow.Types\nimport TypedFlow.Broadcast\nimport TypedFlow.Types.Proofs\n\nfreeVarsT :: forall s t. KnownTyp t => KnownShape s\n  => T s t -> [Int]\nfreeVarsT x = result\n  where f :: forall s' t'. T s' t' -> [Int]\n        f = Memo0.memo (protoFreevars f)\n        result = f x\n\nprotoFreevars :: (forall s' t'. 
T s' t' -> [Int]) -> T s t -> [Int]\nprotoFreevars rec = \\case\n  BroadcastT _ _ _ _ x -> rec x\n  MapT _ s f x -> rec x <> rec (f (T (Variable (Ref (-789) s typeSTyp))))\n  Softmax _ _ x -> rec x\n  DirectBroadcast _ _ _ _ x -> rec x\n  GatherND _ _ _ x y -> rec x <> rec y\n  Noise _ _ _ _ -> []\n  Where cond x y -> rec cond <> rec x <> rec y\n  If cond x y ->  rec cond <> rec x <> rec y\n  T (Variable (Ref i _ _)) -> [i]\n  T _ -> []\n  Unbroadcast _p _u x -> rec x\n  UnOp _op _ x -> rec x\n  MatMul _ _ _ _ x y -> rec x <> rec y\n  BinOp _op _ _ _ _ _ x y -> rec x <> rec y\n  Gather _is _s0 _m _s1 x ix -> rec x <> rec ix\n  Transpose _ _t x -> rec x\n  ReshapeFrom _s x -> rec x\n  Concat _s0  _s1 xs -> mconcat $ htoList $ hmap (\\(Catable _ x) -> K (rec x)) xs\n  Convolution _bs _inChans _outChans _filterShape _s x filters -> rec x <> rec filters\n  Pool _ _ _ _ _ x  -> rec x\n  _ -> error \"protoFreevars: unhandled case\"\n\n\n\n\ngenTrainingPlaceholder :: Scalar TFBool\ngenTrainingPlaceholder = T (ExternalVar (Ref \"training_placeholder\" typeSShape typeSTyp))\n\n\n-- | Zeros\nzeros :: ∀ t (shape :: Shape). KnownNumeric t => KnownShape shape => (T shape t)\nzeros = constant $ knownNum @t $ 0\n\ndefaultT :: ∀ t (shape :: Shape). KnownShape shape => KnownTyp t => (T shape t)\ndefaultT = case typeSTyp @t of\n                 STyp SFloat _ _ -> zeros\n                 STyp SInt _ _ -> zeros\n                 STyp SBool _ _ -> constant False\n                 _ -> error \"defaultT: unhandled case\"\n\n\n-- | Ones\nones :: ∀ t (shape :: Shape). KnownShape shape => KnownNumeric t => (T shape t)\nones = knownNum @t $ constant 1\n\n-- | Identity matrix in dimensions n,n\neye :: ∀ n t. KnownNat n => KnownNumeric t => (T '[n,n] t)\neye = diag 1\n\ndiag :: ∀ n t. KnownTyp t => KnownNat n => T '[n] t ->  T '[n,n] t\ndiag = UnOp (Diag Sat) Unit\n\nexpm :: ∀ n t. 
KnownNat n => KnownNumeric t => T '[n,n] t ->  T '[n,n] t\nexpm = UnOp (ExpM Sat) Unit\n\n-- | @k@=diagonal above which to zero elements. k = 0 is the main diagonal, k < 0 is below it and k > 0 is above.\ntril :: ∀ n t. KnownNat n => KnownNumeric t => Integer -> T '[n,n] t ->  T '[n,n] t\ntril k = UnOp (ZeroTriangle Sat Lower k) Unit\n\ntriu :: ∀ n t. KnownNat n => KnownNumeric t => Integer -> T '[n,n] t ->  T '[n,n] t\ntriu k = UnOp (ZeroTriangle Sat Upper k) Unit\n\n\n-- | Constant\nconstant :: forall s t w. KnownShape s => KnownBits w => KnownKind t => HaskType ('Typ t w) -> T s ('Typ t w)\nconstant c = appRUnit @s #> broadcastTT @s (scalar c)\n\nscalar :: forall t w. KnownBits w => KnownKind t => HaskType ('Typ t w) -> Scalar ('Typ t w)\nscalar = T . Constant\n\nreduceAll :: forall s t. KnownTyp t => KnownShape s =>\n     (∀n s'. (KnownTyp t,KnownShape s') => Axis n s' -> T s' t -> T (Take n s' ++ Drop ('Succ n) s') t) -> Tensor s t -> Tensor '[] t\nreduceAll op x = knownProduct @s ?>\n   op axis0 (reshapeTo ((:*) (productS (typeSShape @s)) Unit) x)\n\n-- | Mean value of the input tensor.\nreduceMeanAll, reduceSumAll, reduceMaxAll, reduceMinAll :: ∀ (s :: Shape) t. KnownNumeric t => KnownShape s => Tensor s t -> Tensor '[] t\nreduceMaxAll = reduceAll reduceMax\nreduceMeanAll = reduceAll reduceMean\nreduceSumAll = reduceAll reduceSum\nreduceMinAll = reduceAll reduceMin\n\nsShapeTake' :: Axis n s -> SList' f s -> SList' f (Take n s)\nsShapeTake' AxZero _s = Unit\nsShapeTake' (AxSucc n) ((:*) x xs) = (:*) x (sShapeTake' n xs)\n\nsShapeDrop' :: Axis n s -> SList' f s -> SList' f (Drop n s)\nsShapeDrop' AxZero s = s\nsShapeDrop' (AxSucc n) ((:*) _ xs) = sShapeDrop' n xs\n\nsShapeDropSucc :: Axis n s -> SList' f s -> SList' f (Drop ('Succ n) s)\nsShapeDropSucc AxZero (_ :* s) = s\nsShapeDropSucc (AxSucc n) (_ :* xs) = sShapeDropSucc n xs\n\n-- | Internal. Use 'reduceSum', etc. instead.\nreduce :: ∀ n s t. 
KnownNumeric t => (KnownShape s) => ReduceOp -> Axis n s -> T s t -> T (Take n s ++ Drop ('Succ n) s) t\nreduce op n x = case axisSplitApp' n of\n  Refl -> UnOp (Axis1Op (sShapeDropSucc n s) (ReduceOp (hlookup n s) op)) (sShapeTake' n s) x\n where s = typeSShape @s\n\n-- | Reduce along a given dimension\nreduceSum, reduceMean, reduceMax, reduceMin :: ∀n s t. (KnownNumeric t,KnownShape s) => Axis n s -> T s t -> T (Take n s ++ Drop ('Succ n) s) t\nreduceSum = reduce Sum\nreduceMean = reduce Mean\nreduceMax = reduce Max\nreduceMin = reduce Min\n\n\n-- | Sum along the first dimension\nreduceSum0 :: ∀ s' n t. KnownNat n => KnownNumeric t => KnownShape s' => Tensor (n ': s') t -> Tensor s' t\nreduceSum0 = reduceSum axis0\n\n\n\naddN :: ∀ s t. KnownNumeric t => KnownShape s => [Tensor s t] -> Tensor s t\naddN [] = zeros\naddN ts = foldr1 (+) ts\n\ninstance (KnownNumeric t, KnownShape s) => Num (T s t) where\n  (+) = (⊕)\n  (*) = (⊙)\n  signum = unOp Sign\n  fromInteger x = knownNum @t $ constant (fromIntegral x)\n  abs = unOp Abs\n  (-) = (⊝)\n  negate = unOp Negate\n\ninstance (KnownFloat b, KnownShape s) => Fractional (T s b) where\n  fromRational x = knownAlgebraic @b $ constant (fromRational x :: HaskType b)\n  (/) = (⊘)\n\ninstance (KnownFloat b, KnownShape s) => Floating (T s b) where\n  pi = knownAlgebraic @b $ constant pi\n  exp = unFlOp Exp\n  log = unFlOp Log\n  sin = unFlOp Sin\n  cos = unFlOp Cos\n  asin = unFlOp Asin\n  acos = unFlOp Acos\n  sinh = unFlOp Sinh\n  cosh = unFlOp Cosh\n  asinh = unFlOp Asinh\n  acosh = unFlOp Acosh\n  tanh = unFlOp Tanh\n  atan = unFlOp Atan\n  atanh = unFlOp Atanh\n  sqrt = unFlOp Sqrt\n\n-- | Pretend that the argument is a constant for the purposes of\n-- gradient computation\nstopGradient :: ∀ s t. KnownTyp t => KnownShape s => Tensor s t -> Tensor s t\nstopGradient = appRUnit @s #> UnOp StopGradient (typeSShape @s)\n\n-- | Divide tensors, broacasting along shape @s@\n(⊘) :: forall s t. 
KnownAlgebraic t => KnownShape s => T s t -> T s t -> T s t\n(⊘) = binOp Divide\n\n-- | Divide tensors, broacasting along shape @s@\nfloorDiv :: forall s w. KnownBits w => KnownShape s => T s ('Typ 'Int w) -> T s ('Typ 'Int w) -> T s ('Typ 'Int w)\nfloorDiv = binOp IntegerDiv\n\n\n-- | Indexwise equality test.\nequal :: forall s t. (KnownShape s, KnownTyp t) => Tensor s t -> Tensor s t -> Tensor s TFBool\nequal = binOp (Equal)\n\n-- | Indexwise operator\n(⊕), (⊝), (⊙)  :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s t\n(⊝) = binOp Subtract\n(⊙) = binOp Multiply\n(⊕) = binOp Add\n\nmaxT,minT :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s t\nmaxT = binOp Maximum\nminT = binOp Minimum\n\nmkComplex :: KnownBits w => KnownShape s => Tensor s (Flt w) -> Tensor s (Flt w) -> Tensor s ('Typ 'Cmplx w)\nmkComplex = binOp MkComplex\n\nlessThan :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s TFBool\nlessThan = binOp (Comparision Less)\n\nlessOrEqualThan :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s TFBool\nlessOrEqualThan = binOp (Comparision LessOrEqual)\n\ngreaterThan :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s TFBool\ngreaterThan = binOp (Comparision Greater)\n\nlogicAnd :: ∀ (s :: Shape). (KnownShape s) => Tensor s TFBool -> Tensor s TFBool-> Tensor s TFBool\nlogicAnd = binOp (Logic And)\n\n\ninfixl 7 ⊙,⊘\ninfixl 6 ⊕,⊝\n\n\n-- | Matrix multiplication (note that shape @s@ is preserved)\nmatmul :: forall m n o t. KnownNumeric t => KnownNat m => KnownNat o => KnownNat n => KnownTyp t => T '[n,o] t -> T '[o,m] t -> T '[n,m] t\nmatmul = MatMul Unit Sat Sat Sat\n\nunOp :: forall s t. KnownShape s => KnownNumeric t => Num1Op -> T s t -> T s t\nunOp op = appRUnit @s #> UnOp (Num1Op op)  (typeSShape @s)\n\nunFlOp :: forall s t. 
KnownBits t => KnownShape s => Float1Op -> T s (Flt t) -> T s (Flt t)\nunFlOp op = appRUnit @s #> UnOp (Float1Op op) (typeSShape @s)\n\nbinOp :: forall s t u. KnownShape s => KnownTyp t => Simple2Op t u -> T s t -> T s t -> T s u\nbinOp op = appRUnit @s #> BinOp (Simple2Op op) (typeSShape @s) Unit typeSTyp Unit typeSTyp\n\nconjugate :: ∀ s w. KnownShape s => KnownBits w => T s ('Typ 'Cmplx w) ->  T s ('Typ 'Cmplx w)\nconjugate = appRUnit @s #> UnOp Conjugate (typeSShape @s)\n\nrealPart :: ∀ s w. KnownShape s => KnownBits w => T s ('Typ 'Cmplx w) ->  T s ('Typ 'Float w)\nrealPart = appRUnit @s #> UnOp RealPart (typeSShape @s)\n\nsigmoid, relu, square, round, floor, hardSigmoid\n   :: ∀ s t. (KnownShape s, KnownBits t)\n   => Tensor s ('Typ 'Float t) -> Tensor s ('Typ 'Float t)\nsigmoid = unFlOp Sigmoid\nhardSigmoid = unFlOp HardSigmoid\nsquare = unOp Square\nrelu = unFlOp Relu\n\nfloorMod :: ∀ s t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s t\nfloorMod = binOp FloorMod\n\n-- Unfortunately RealFrac is utterly broken; so we have to do this:\nround = unFlOp Round\nfloor = unFlOp Floor\n\n-- | Take a slice at dimension n from i to j.\nslice :: forall i j s t n. KnownTyp t => KnownShape s => KnownNat j => KnownNat i => (i <= j, j <= At n s, KnownLen s) =>\n         Axis n s -> Tensor s t -> Tensor (Take n s ++ ((j-i) ': Drop ('Succ n) s)) t\nslice n = case axisSplitApp' n of\n  Refl -> UnOp (Axis1Op (sShapeDropSucc n s) (SliceOp (Proxy @(j-i)) (hlookup n s) (natVal (Proxy @i)) (natVal (Proxy @j))))\n               (sShapeTake' n s)\n where s = typeSShape @s\n\n\nslice1 :: forall i j m n s t. KnownShape s => KnownNat m => KnownNat n => KnownTyp t => KnownNat j => KnownNat i => (i <= j, j <= m, KnownLen s) =>\n         Tensor (n ': m ': s) t -> Tensor (n ': (j-i) ': s) t\nslice1 = slice @i @j axis1\n\nslice0 :: forall i j m s t. 
KnownShape s => KnownNat m => KnownTyp t => KnownNat j => KnownNat i => (i <= j, j <= m, KnownLen s) =>\n         Tensor (m ': s) t -> Tensor ((j-i) ': s) t\nslice0 = slice @i @j axis0\n\n\n\n-- | Concatenate tensors on dimension @n@. Recommended: use @zipWithTT (concat0 ...)@ instead.\nconcatT :: ∀ n d1 d2 s t. KnownNat d2 => KnownNat d1 => KnownShape s => (KnownTyp t, (d1+d2) ~ At n s) =>\n    Axis n s -> T (Take n s ++ (d1 ': Drop ('Succ n) s)) t -> T (Take n s ++ (d2 ': Drop ('Succ n) s)) t -> T s t\nconcatT n = case axisSplitApp' n of Refl -> concatT' (sShapeTake' n s) d1 d2 (sShapeDropSucc n s)\n  where s = typeSShape @s; d1 = natSat @d1; d2 = natSat @d2\n\n-- | Concatenate tensors on the first dimension\nconcat0, (©) :: ∀ d1 d2 ys t. KnownTyp t => KnownShape ys => KnownNat d2 => KnownNat d1 => (KnownLen ys) => T (d1 ': ys) t -> T (d2 ': ys) t -> T ((d1 + d2) ': ys) t\nconcat0 = concatT axis0\n\n(©) = concat0\n\n-- | Concatenate tensors on the second dimension\nconcat1 :: ∀ n ys d1 d2 t. KnownShape ys => KnownNat n => KnownNat d2 => KnownNat d1 => KnownTyp t => (KnownLen ys) =>  T (n ': d1 ': ys) t -> T (n ': d2 ': ys) t -> T (n ': (d1 + d2) ': ys) t\nconcat1 = concatT axis1\n\n-- | Add an extra dimension at axis (@n@) of size 1.\nexpandDim :: forall n s t. KnownTyp t => KnownShape s => (PeanoNat n <= Length s) => Tensor s t -> Tensor (Take n s ++ (1 ': Drop n s)) t\nexpandDim x =\n  -- Product (Take n s ++ (1 ': Drop n s))\n  prodHomo @(Take n s) @(1' : Drop n s) #>\n  -- Product (Take n s) * Product (Drop n s)\n  prodHomo @(Take n s) @(Drop n s) #>\n  -- Product (Take n s ++ (1 ': Drop n s))\n  takeDrop @s @n #>\n  -- Product s\n  reshapeFrom (typeSShape @s) x\n\n-- +expandDim :: forall n s t. KnownTyp t => KnownShape s => Axis n s -> Tensor s t -> Tensor (Take n s ++ (1 ': Drop n s)) t\n-- +expandDim ax x = case expandDimProof ax s of Refl -> reshapeFrom s x\n\n-- | Add an extra dimension at axis (0) of size 1.\nexpandDim0 :: ∀ s t. 
KnownShape s => KnownTyp t => KnownLen s => Tensor s t -> Tensor (1 ': s) t\nexpandDim0 = reshape\n\n-- | Add an extra dimension at axis (1) of size 1.\nexpandDim1 :: ∀ n s t. KnownNat n => KnownTyp t => KnownShape s => Tensor (n ': s) t -> Tensor (n ': 1 ': s) t\nexpandDim1 = reshape\n\n\n-- | Flatten all the dimensions of the tensor\nflattenAll :: forall s t. KnownTyp t => KnownShape s => Tensor s t -> Tensor '[Product s] t\nflattenAll = knownProduct @s ?> reshape\n\ninflateAll :: forall s t. KnownTyp t => KnownShape s => Tensor '[Product s] t -> Tensor s t\ninflateAll = knownProduct @s ?> reshape\n\n\nsqueeze0 :: ∀ s t. KnownTyp t => (KnownShape s) => Tensor (1 ': s) t -> Tensor s t\nsqueeze0 = reshape\n\natShape :: SList s -> T s t -> T s t\natShape _ x = x\n\n-- | Reshape a tensor so that the last two dimensions are collapsed\nflattenN2 :: ∀ s m n t. KnownTyp t => (KnownNat m, KnownNat n, KnownShape s) => Tensor (s ++ '[m,n]) t -> Tensor (s ++ '[m*n]) t\nflattenN2  = prodHomo @s @'[m,n] #>\n             prodHomo @s @'[m*n] #>\n             knownAppend @s @'[m*n] ?>\n             knownAppend @s @'[m,n] ?>\n             reshape\n\n-- | Reshape a tensor so that the first three dimensions are collapsed\nflatten3 :: ∀ m n o s t. KnownTyp t => (KnownNat m, KnownNat n, KnownNat o, KnownShape s) => Tensor (m ': n ': o ': s) t -> Tensor (m*n*o ': s) t\nflatten3  =  -- (m * (n * (o * Product s)))\n             prodAssoc @m @n @(o * Product s) #>\n             -- (m * n) * (o * Product s)\n             prodAssoc @(m * n) @o @(Product s) #>\n             -- ((m * n) * o) * Product s\n             reshape\n\n-- | Reshape a tensor so that the first two dimensions are collapsed\nflatten12 :: ∀ m n o s t. 
KnownTyp t => KnownNat o => (KnownNat m, KnownNat n, KnownShape s) => Tensor (o ': m ': n ': s) t -> Tensor (o ': m*n ': s) t\nflatten12 = prodAssoc @m @n @(Product s) #> reshape\n\n-- | Reshape a tensor so that the first dimension is expanded into three.\ninflate3 :: ∀ m n o s t. KnownTyp t => (KnownNat m, KnownNat n, KnownNat o, KnownShape s) => Tensor (m*n*o ': s) t -> Tensor (m ': n ': o ': s) t\ninflate3 = -- (m * (n * (o * Product s)))\n           prodAssoc @m @n @(o * Product s) #>\n           -- (m * n) * (o * Product s)\n           prodAssoc @(m * n) @o @(Product s) #>\n           -- ((m * n) * o) * Product s\n           reshape\n\n-- | Reshape a tensor so that the first two dimensions are collapsed\ninflate12 :: ∀ m n o s t. KnownTyp t => KnownNat o => (KnownNat m, KnownNat n, KnownShape s) => Tensor (o ': m*n ': s) t -> Tensor (o ': m ': n ': s) t\ninflate12 = prodAssoc @m @n @(Product s) #> reshape\n\n\n-- | Access the last element in a tensor (in the 0th dimension)\nlast0 :: ∀ n s t. KnownShape s => KnownTyp t => KnownNat n => KnownLen s => T (n ': s) t -> Tensor s t\nlast0 = nth0 (natVal (Proxy @n) - 1)\n\n-- | Access the nth element in a tensor (in the 0th dimension)\nnth0 :: ∀ n s t. KnownTyp t => KnownNat n => KnownShape s => Integer -> T (n ': s) t -> Tensor s t\nnth0 i x = UnOp (Axis1Op (typeSShape @s) (AccessOp (natSat @n) i)) Unit x\n\n-- | Access the nth element in a tensor (in the 0th dimension), with a static index\nnth0' :: ∀ n m s t. KnownNat m => KnownTyp t => KnownShape s => KnownNat n => KnownLen s => n < m => T (m ': s) t -> Tensor s t\nnth0' = nth0 (natVal (Proxy @n))\n\nvecToNP :: forall a f n k. (a -> f 1) -> V n a -> (forall xs. Sum xs ~ n => NP f xs -> k) -> k\nvecToNP _f VUnit k = k Unit\nvecToNP f (x :** xs) k = vecToNP f xs $ \\xs' -> k (f x :* xs')\n\nstackT :: ∀ s0 s (n::Nat) t. 
KnownShape s => KnownShape s0 => KnownNat n => (KnownLen s0) => V n (T (s0 ++ s) t) -> Tensor (s0 ++ (n ': s)) t\nstackT v = vecToNP @(T (s0++s) t) @(Catable s0 s t)\n             (\\x -> (Catable (natSat @1) $ (prodHomoS s0 s #>\n                                            prodHomoS s0 (natSat @1 :* s) #>\n                                            knownAppend @s0 @s ?>\n                                            knownSShape (s0 .+. natSat @1 :* s) ?>\n                                            reshape x)))\n             v $ (Concat (typeSShape @s0)  (typeSShape @s)) \n  where s = typeSShape @s; s0 = typeSShape @s0\n\n\n-- | Concatenate @n@ tensors along the first dimension\nstack0 :: ∀ s (n::Nat) t. KnownNat n => KnownShape s => (KnownLen s) => V n (T s t) -> Tensor (n ': s) t\nstack0 = stackT @'[]\n\n-- | Concatenate @n@ tensors along the second dimension\nstack1 :: ∀ s (n::Nat) m t. KnownNat n => KnownNat m => KnownShape s => (KnownLen s) => V n (T (m ': s) t) -> Tensor (m ': n ': s) t\nstack1 = stackT @'[m]\n\n-- | Concatenate @n@ tensors along the last dimension\nstackN :: ∀ s (n::Nat) t. KnownNat n => KnownShape s => V n (T s t) -> Tensor (s ++ '[n]) t\nstackN = appRUnit @s #>\n         stackT @s @'[]\n\n\n-- | Split a tensors into @n@ tensors along the first dimension\nunstack0 :: ∀ s (n::Nat) t. KnownTyp t => KnownNat n => KnownShape s => (KnownLen s) => Tensor (n ': s) t -> V n (T s t)\nunstack0 x = fmap (`nth0` x) (vcount @n)\n\n-- | Stack a tensor vector. (To be used on literal lists of tensors.)\nlitStack0 :: KnownShape s => KnownLen xs => TV s t xs -> Tensor (Length xs ': s) t\nlitStack0 tv = knownSList tv ?> stack0 $ toV tv\n  where toV :: TV s t xs -> V (Length xs) (T s t)\n        toV Unit = VUnit\n        toV (K x :* xs) = x :** toV xs\n\n-- | Generate a mask of given length for each sequence.\nsequenceMask :: forall maxlen. 
KnownNat maxlen => Tensor '[] Int32 -> Tensor '[maxlen] TFBool\nsequenceMask lens = mapT (lens `lessThan`) (range @maxlen)\n\n-- | simple broadcasting of a tensor (like a zero-arity map)\nbroadcastT :: forall n s t. KnownShape s => KnownNat n => KnownTyp t => KnownLen s => T s t ->  T (n ': s) t\nbroadcastT x = BroadcastT Nothing False (natSat @n) typeSShape x\n\n-- | simple broadcasting of a tensor\nbroadcastTT :: forall a s t. KnownShape s => KnownTyp t => KnownShape a => KnownLen s => T s t ->  T (a ++ s) t\nbroadcastTT x = prodHomo @a @s #>\n                knownProduct @a ?>\n                knownAppend @a @s ?>\n                reshape (broadcastT @(Product a) x)\n\n\n\n-- | Map a function along the first dimension of a tensor\nmapT :: forall n s r t u. KnownShape s => KnownNat n => KnownTyp t => KnownLen r => KnownLen s\n     => (T s t -> T r u) ->  T (n ': s) t -> T (n ': r) u\nmapT f x = MapT Sat typeSShape f x\n\n-- | Map a function along the few first dimensions of a tensor, given by the first type parameter\nmapTT :: forall a s t r u. 
KnownShape r => KnownShape a => KnownTyp u => KnownLen r => KnownShape s => KnownTyp t\n  => (T s t -> T r u) ->  T (a ++ s) t -> T (a ++ r) u\nmapTT f x = prodHomo @a @r #>\n            prodHomo @a @s #>\n            knownProduct @a ?>\n            knownAppend @a @r ?>\n            knownAppend @a @s ?>\n            reshape (mapT @(Product a) f (reshape x))\n\n-- | zip  a function along the first dimension of two tensors tensors\nzipWithT :: forall (n :: Nat) (s :: [Nat]) (t :: Typ) (s1 :: [Nat]) (t1 :: Typ) (s2 :: Shape)  (t2 :: Typ).\n            KnownShape s => KnownShape s1 => KnownNat n=> KnownTyp t => KnownTyp t1\n            => (T s t -> T s1 t1 -> T s2 t2)\n            -> Tensor (n ': s) t\n            -> Tensor (n ': s1) t1\n            -> Tensor (n ': s2) t2\nzipWithT f x y = ZipT Sat typeSShape typeSShape f x y\n\n-- | zip  a function along the few first dimensions of a tensor, given by the first type parameter\nzipWithTT :: forall a (s :: [Nat]) (s1 :: [Nat]) (s2 :: Shape) (t :: Typ) (t1 :: Typ)  (t2 :: Typ).\n            KnownTyp t1 => KnownTyp t => KnownShape s => KnownShape s1 => KnownShape a => KnownShape s2 => KnownTyp t2\n            => (T s t -> T s1 t1 -> T s2 t2)\n            -> Tensor (a ++ s) t\n            -> Tensor (a ++ s1) t1\n            -> Tensor (a ++ s2) t2\nzipWithTT f x y = \n            prodHomo @a @s1 #>\n            prodHomo @a @s2 #>\n            prodHomo @a @s #>\n            knownProduct @a ?>\n            knownAppend @a @s1 ?>\n            knownAppend @a @s2 ?>\n            knownAppend @a @s ?>\n            reshape (zipWithT @(Product a) f (reshape x) (reshape y))\n\nzipWith3T :: forall (n :: Nat) (s :: [Nat]) (t :: Typ) (s1 :: [Nat]) (t1 :: Typ) (s2 :: Shape)  (t2 :: Typ) (s3 :: Shape)  (t3 :: Typ).\n             KnownShape s => KnownShape s1 => KnownShape s2 => KnownShape s3 => KnownNat n => KnownTyp t3 => KnownTyp t => KnownTyp t1 => KnownTyp t2\n            => (T s t -> T s1 t1 -> T s2 t2 -> T s3 t3)\n            -> 
Tensor (n ': s) t\n            -> Tensor (n ': s1) t1\n            -> Tensor (n ': s2) t2\n            -> Tensor (n ': s3) t3\nzipWith3T = Zip3T Sat typeSShape typeSShape typeSShape\n\n-- | Size-preserving convolution operation.\nconvolution :: forall outputChannels filterSpatialShape inChannels s t.\n               KnownShape s => KnownNat inChannels => KnownNat outputChannels => KnownShape filterSpatialShape\n            => KnownAlgebraic t\n            => Length filterSpatialShape <= 3\n            => Length s ~ Length filterSpatialShape\n            => T (s ++ '[inChannels]) t -- ^ input tensor\n            -> T (filterSpatialShape ++ '[inChannels,outputChannels]) t -- ^ filters\n            -> T (s ++ '[outputChannels]) t\nconvolution x filters = knownAppend @s @'[outputChannels] ?>\n                        knownAppend @s @'[inChannels] ?>\n  squeeze0 (Convolution (natSat @1) (natSat @inChannels) (natSat @outputChannels) (typeSShape @filterSpatialShape) (typeSShape @s)\n             (expandDim0 x)\n             filters)\n\n\n-- | Softmax along the first dimension\nsoftmaxInternal :: forall bs n w. KnownNat bs => KnownBits w => KnownNat n => T '[bs,n] ('Typ 'Float w) -> T '[bs,n] ('Typ 'Float w)\nsoftmaxInternal = Softmax (natSat @bs) (natSat @n)\n\nsoftmax0 :: forall n w.  KnownBits w => KnownNat n\n         => T '[n] (' Typ 'Float w) -> T '[n] ('Typ 'Float w)\nsoftmax0 = reshape . softmaxInternal . reshape @[1,n]\n\n-- | Softmax along the second dimension\nsoftmax1 :: forall n m w.  KnownBits w => KnownNat n => KnownNat m\n         => T '[m,n] ('Typ 'Float w) -> T '[m,n] ('Typ 'Float w)\nsoftmax1 = mapT softmax0\n\nargmaxInternal :: forall n s0 s1 t u. 
KnownNat n => KnownNumeric t => KnownBits u => Sat KnownNat n -> SShape s0 -> SShape s1 -> T (s0 ++ (n ': s1)) t -> T (s0 ++ s1) ('Typ 'Int u)\nargmaxInternal _n s0 s1 = UnOp (Axis1Op s1 (ArgMax (natSat @n))) s0\n\naxisSplitApp :: Axis n s -> (Take n s ++ Drop n s) :~: s\naxisSplitApp AxZero = Refl\naxisSplitApp (AxSucc n) = case axisSplitApp n of\n  Refl -> Refl\n\n\naxisSplitApp' :: Axis n s -> (Take n s ++ (At n s ': Drop ('Succ n) s)) :~: s\naxisSplitApp' AxZero = Refl\naxisSplitApp' (AxSucc n) = case axisSplitApp' n of\n  Refl -> Refl\n\n\n-- | Argmax along axis @n@\nargmax :: forall m n u s t. (KnownShape s, KnownBits u, KnownNat m, KnownNumeric t) => Axis n s -> Tensor (Take n s ++ (m ': Drop n s)) t -> Tensor s ('Typ 'Int u)\nargmax n = case axisSplitApp n of\n  Refl -> argmaxInternal (natSat @m) (sShapeTake' n (typeSShape @s)) (sShapeDrop' n s)\n  where s = typeSShape @s\n\n-- | Argmax along the first dimension\nargmax0 :: forall u n s t. (KnownNat n, KnownShape s, KnownBits u, KnownNumeric t) => T (n ': s) t -> T s ('Typ 'Int u)\nargmax0 = argmaxInternal (natSat @n) Unit (typeSShape @s)\n\n-- | Argmax along the second dimension\nargmax1 :: forall u m n s t. (KnownNat n, KnownNat m, KnownShape s, KnownBits u, KnownNumeric t) => T (m ': n ': s) t -> T (m ': s) ('Typ 'Int u)\nargmax1 = argmaxInternal (natSat @n) (natSat @m :* Unit) (typeSShape @s)\n-- argmax1 = mapT argmax0 -- equivalent?\n\n-- | Cast the element type.\ncast :: forall u s t. 
KnownTyp t => KnownShape s => KnownTyp u => T s t -> T s u\ncast = appRUnit @s #> UnOp Cast (typeSShape @s)\n\n-- | (dense) softmax cross entropy with logits.\nsoftmaxCrossEntropyWithLogits :: forall numClasses.\n     KnownNat numClasses => Tensor '[numClasses] Float32 -- ^ labels\n  -> Tensor '[numClasses] Float32 -- ^ logits\n  -> Tensor '[] Float32\nsoftmaxCrossEntropyWithLogits  =\n  BinOp SoftmaxCrossEntropyWithLogits\n  Unit (typeSShape @ '[numClasses]) typeSTyp (typeSShape @ '[numClasses]) typeSTyp\n\n\n-- | Computes sigmoid cross entropy given logits. Measures the\n-- probability error in discrete classification tasks in which each\n-- class is independent and not mutually exclusive. For instance, one\n-- could perform multilabel classification where a picture can contain\n-- both an elephant and a dog at the same time. See\n-- https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits\nsigmoidCrossEntropyWithLogits :: forall s w.\n  KnownBits w => KnownShape s => Tensor s (Flt w) -- ^ labels\n                              -> Tensor s (Flt w) -- ^ logits\n                              -> Tensor s (Flt w)\nsigmoidCrossEntropyWithLogits  =\n  appRUnit @s #> BinOp SigmoidCrossEntropyWithLogits \n    (typeSShape @s)      Unit typeSTyp Unit typeSTyp\n\n-- | sparse softmax cross entropy with logits.\nsparseSoftmaxCrossEntropyWithLogits :: forall numClasses t.\n   KnownNat numClasses => KnownBits t =>\n  Tensor '[] Int32                   -- ^ desired label\n  -> Tensor '[numClasses] (Flt t) -- ^ predictions for each label\n  -> Tensor '[] (Flt t) \nsparseSoftmaxCrossEntropyWithLogits  =\n  BinOp SparseSoftmaxCrossEntropyWithLogits Unit Unit typeSTyp (typeSShape @ '[numClasses]) typeSTyp\n\nreverseT :: KnownTyp t => KnownNat n => T '[n] t -> T '[n] t\nreverseT = UnOp (Axis1Op Unit (ReverseT Sat)) Unit\n\n-- | One hot vector along axis 0\noneHot0 :: forall numClasses w s t. 
KnownNat numClasses => KnownNumeric t => KnownBits w =>\n  (KnownShape s) =>\n  Tensor s ('Typ 'Int w) -> Tensor (numClasses ': s) t\noneHot0 = UnOp (Axis1Op (typeSShape @s) (OneHot Sat)) Unit\n\n-- | One hot vector along axis 1\noneHot1 :: forall numClasses w s m t. KnownBits w =>KnownShape s => KnownNat numClasses => KnownNat m => KnownNumeric t => Tensor (m ': s) ('Typ 'Int w) -> Tensor (m ': numClasses ': s) t\noneHot1 = mapT oneHot0\n\n-- | Generate a random tensor whose distribution is given. A new noise\n-- is sampled for each element in a batch.\nnoise :: KnownShape s => Distribution s t -> Gen (T s t)\nnoise d = do\n  noiseId <- GPId -- necessary for correct broadcasting behaviour\n  return $ Noise noiseId Unit typeSShape d\n\n-- | Clip a tensor\nclipByValue :: forall s t. KnownShape s => KnownBits t => Float -> Float -> T s (Flt t) -> T s (Flt t)\nclipByValue lo hi = appRUnit @s #> UnOp (Float1Op (ClipByValue lo hi)) (typeSShape @s)\n\n-- | (where_ c x y)[i] = if c[i] then x[i] else y[i]\nwhere_ :: T s TFBool -> T s t -> T s t -> T s t\nwhere_ = Where\n\n\n-- | Selection of a tensor (note: this is a strict operation)\nif_ :: forall s t. KnownShape s => Scalar TFBool -> T s t -> T s t -> T s t\nif_ = If -- FIXME: part of the workaround for https://github.com/tensorflow/tensorflow/issues/21901\n-- if_ x = appRUnit @s $ where_ (broadcastTT @s x)\n\n-- | @(gather x ix)[k] = x[ix[k]]@. See https://www.tensorflow.org/api_docs/python/tf/gather\ngather :: forall n indexShape s t. KnownShape s => KnownNat n => KnownShape indexShape => T (n ': s) t -> T indexShape Int32 -> T (indexShape ++ s) t\ngather = Gather typeSShape Unit (natSat @n) typeSShape\n-- gather params ix = GatherND (typeSShape @'[n]) (typeSShape @s) (typeSShape @indexShape) params $\n--   prodHomo @indexShape @'[1] $\n--   (reshapeAuto ix)\n\n-- | @(lookup i xs) = xs[i]@. This function returns an element of a\n-- tensor at a dynamic index. 
This is a version of 'gather'\n-- specialised to a scalar index.\nlookupT :: KnownShape xs => KnownNat n => Scalar Int32 -> Tensor (n ': xs) t -> Tensor xs t\nlookupT ix xs = gather xs ix\n\n-- | x by y maxpool layer.\nmaxPool2D :: forall windowx windowy height width channels t.\n             KnownNat height => KnownNat width\n          => KnownNat channels\n          => (KnownNat windowx, KnownNat windowy, KnownBits t) =>\n             T '[windowx*width,windowy*height,channels] (Flt t)\n          -> T '[width,height,channels] (Flt t)\nmaxPool2D x = squeeze0 (Pool (natSat @1) (typeSShape @'[windowx,windowy]) MaxPool (natSat @channels) (typeSShape @'[width,height]) (expandDim0 x))\n\n-- | maxpool layer. window size is the first type argument.\nmaxPool1D :: forall window width channels t.\n             KnownNat width => KnownNat channels => (KnownNat window,KnownBits t) =>\n             T '[window*width,channels] (Flt t) -> T '[width,channels] (Flt t)\nmaxPool1D x = squeeze0 (Pool (natSat @1) (typeSShape @'[window]) MaxPool (natSat @channels) (typeSShape @'[width]) (expandDim0 x))\n\n\ndoExtractVars :: Gen a -> (a, GState, [VarInfo])\ndoExtractVars p = runRWS (extractVars p) () initialGstate\n\nextractVars :: Gen a -> RWS () [VarInfo] GState a\nextractVars (GPState f) = state f\nextractVars GPId = do\n  GState {..} <- get\n  put GState {nextVar=nextVar+1,..}\n  return nextVar\nextractVars (GPVariable trainable name i) = do\n  -- i <- mapM extractVars initial\n  case i of\n    Nothing -> return ()\n    Just i' -> when (not (null (freeVarsT i'))) $ error \"aaaaaaaaarrrrghhh\"\n  GState {..} <- get\n  let r = Ref name typeSShape typeSTyp\n  tell [VarInfo trainable r i]\n  return r\nextractVars (GPApp a b) = do f <- extractVars a; x <- extractVars b; return (f x)\nextractVars (GPBind a f) = do\n  a' <- extractVars a\n  extractVars (f a')\nextractVars (GPReturn x) = return x\n\n"
  },
  {
    "path": "TypedFlow/Broadcast.hs",
    "content": "{-# LANGUAGE InstanceSigs #-}\n{-|\nModule      : TypedFlow.Abstract\nDescription : Abstract Tensor representations\nCopyright   : (c) Jean-Philippe Bernardy, 2018\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n\nThis module provides operations on the abstract representation of\ntensor operations. It is not normally imported directly by users.\n-}\n\n{-# LANGUAGE ApplicativeDo #-}\n{-# LANGUAGE LambdaCase #-}\n{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE MultiParamTypeClasses #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE RecordWildCards #-}\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeInType #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-# LANGUAGE CPP #-}\n#if __GLASGOW_HASKELL__ >= 806\n{-# LANGUAGE NoStarIsType #-}\n#endif\n\nmodule TypedFlow.Broadcast (\n  -- * broadcasting\n  doBroadcast,doBroadcastSingle,mapPlaceHolders, ConsSh, unopInputShape,\n  -- * helpers which are also useful elsewhere\n  -- ** reshapes\n  reshape, reshapeAuto, reshapeFrom, reshapeTo, inflate2, flatten2,\n  permToFun, \n  -- ** transpositions\n  transpose01, transposeN, transposeN', transposeN01,\n  -- ** others\n  concatT', range,\n  ) where\n\nimport Control.Monad.State\n-- import Data.Kind (Type,)\nimport Data.Proxy\nimport Data.Type.Equality\nimport GHC.TypeLits\nimport Prelude hiding (RealFrac(..))\nimport System.IO.Unsafe\nimport TypedFlow.Memo2 hiding (Comp)\nimport TypedFlow.Types (T(..), type 
(∘)(..))\nimport TypedFlow.Types hiding (T)\nimport TypedFlow.Types.Proofs\n\n\ndata GS = GS { gsUnique :: Unique }\n\n\ntype G = StateT GS IO \n\nrunG :: Unique -> G x -> x\nrunG u m = fst (unsafePerformIO  (runStateT m GS { gsUnique = u }))\n\ndoBroadcastSingle :: forall s t. (KnownShape s, KnownTyp t) => T s t -> T s t\ndoBroadcastSingle x = case doBroadcast @'[ '(\"_doBroadcastSingle\" , s , t) ] (PHT x :* Unit) of\n  PHT x' :* Unit -> x'\n  \n  \n\ndoBroadcast :: All KnownPlaceholder ps => Placeholders ps -> Placeholders ps\ndoBroadcast phs = runG 0 $ do\n  F3m' bc <- mkBroadcastFn\n  let broadcast :: forall n s t. BroadcastFn n s t\n      broadcast = unwrapBCFn bc\n  F2m' gBC' <- mkGenerateBC broadcast\n  let generateBC :: forall s t. GenBCFn s t\n      generateBC = unwrapGBCFn gBC'\n      generateBCMany :: forall ps. All KnownPlaceholder ps => Placeholders ps -> G (Placeholders ps)\n      generateBCMany = \\case\n        Unit -> return Unit\n        (PHT x :* xs) -> do\n          x' <- generateBC x\n          xs' <- generateBCMany xs\n          return (PHT x' :* xs')\n  generateBCMany phs\n\n\ngetUnique :: G Unique\ngetUnique = do\n  u <- gets ((1+) . gsUnique)\n  modify $ \\GS {} -> GS {gsUnique = u,..}\n  return u\n\ngenerateBC' :: (forall n s t proxy. KnownTyp t => KnownShape s => KnownNat n => Unique -> Bool -> proxy n -> T s t -> G (T (n : s) t))\n            -> (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> G (T s' t'))\n            -> forall s t. KnownTyp t\n            => SShape s\n            -> T s t\n            -> G (T s t)\ngenerateBC' broadcast rec (n@Sat :* sR) (Zip3T _ _s1 _s2 _s3 f x y z) = knownSShape sR ?> do\n  u <- getUnique\n  -- ATTN: it is critical not to do recursive calls to x,y,z here. 
Doing so would create new nodes, losing sharing, and creating problems down the line.\n  a' <- rec sR (f (Unbroadcast n u x) (Unbroadcast n u y) (Unbroadcast n u z))\n  broadcast u False n a'\ngenerateBC' broadcast rec (n@Sat :* sR) (ZipT _ _s1 _s2 f x y) = knownSShape sR ?> do\n  u <- getUnique\n  a' <- rec sR (f (Unbroadcast n u x) (Unbroadcast n u y))\n  broadcast u False n a'\ngenerateBC' broadcast rec (n@Sat :* sR) (MapT _ _s' f x) = knownSShape sR ?> do\n  u <- getUnique\n  a' <- rec sR (f (Unbroadcast n u x))\n  broadcast u False n a'\ngenerateBC' broadcast rec (n@Sat :* sR) (BroadcastT maybeUnique varyNoise _ _s' a) = knownSShape sR ?> do\n  u <- case maybeUnique of\n          Nothing -> getUnique\n          Just u' -> return u'\n  a' <- rec sR a\n  broadcast u varyNoise n a'\ngenerateBC' _ _ _ (n@T {}) = return n\ngenerateBC' _ _ _ (n@Noise {}) = return n\ngenerateBC' _ rec _ (BinOp op s0 s1 t1 s2 t2 x y) = knownTyp t1 $ knownTyp t2 $ BinOp op s0 s1 t1 s2 t2 <$> (rec (s0 .+. s1) x) <*> (rec (s0 .+. s2) y)\ngenerateBC' _ rec _ (UnOp op s0 x) = UnOp op s0 <$> rec (s0 .+. unopInputShape op) x\ngenerateBC' _ rec sR (Unbroadcast p u' x) = Unbroadcast p u' <$> rec (p :* sR) x\ngenerateBC' _ rec _ (DirectBroadcast s0 s1 s2 s3 x) = DirectBroadcast s0 s1 s2 s3 <$> (rec (s0 .+. s2) x)\ngenerateBC' _ rec _ (ReshapeFrom s0 x) = reshapeFrom s0 <$> rec s0 x\ngenerateBC' _ rec _ (Transpose s0 t x) = Transpose s0 t <$> (rec s0 x)\ngenerateBC' _ rec _ (Concat s0 s1 xs) = Concat s0 s1 <$> hTraverse (\\(Catable m x) -> Catable m <$> (rec (s0 .+. m :* s1) x)) xs\ngenerateBC' _ rec _ (Gather is s0 m s1 x ix) = Gather is s0 m s1 <$> (rec (s0 .+. m :* s1) x) <*> rec (s0 .+. is) ix\ngenerateBC' _ rec _ (GatherND cs es is x ix) = GatherND cs es is <$> (rec (cs .+. es) x) <*> (rec (is *: sListLenAsNat cs) ix)\ngenerateBC' _ rec _ (MatMul s0 a b c x y) = MatMul s0 a b c <$> (rec (s0 .+. a :* b :* Unit) x) <*> (rec (s0 .+. 
b :* c :* Unit) y)\ngenerateBC' _ rec sR (Where cond x y) = Where <$> rec sR cond <*> rec sR x <*> rec sR y\ngenerateBC' _ rec sR (If cond x y) = If <$> rec Unit cond <*> rec sR x <*> rec sR y\ngenerateBC' _ rec _ (Convolution bs@Sat inChans outChans filterShape s0 x filters) = Convolution bs inChans outChans filterShape s0 <$> (rec (bs :* (s0 *: inChans)) x) <*> (rec (filterShape .+. inChans :* outChans :* Unit) filters)\ngenerateBC' _ rec _ (Pool bs@Sat window pt numChans outSpatial x) = Pool bs window pt numChans outSpatial <$> rec (bs :* (zipWithMulSShapes window outSpatial *: numChans)) x\ngenerateBC' _ rec _ (Softmax bs n x) = Softmax bs n <$> (rec (bs :* n :* Unit) x)\ngenerateBC' _ _ _ _ = error \"generateBC': unhandled case\"\n\n\n\n(<&&>) :: Applicative f => f Bool -> f Bool -> f Bool\nx <&&> y = (&&) <$> x <*> y\n\n-- | True if the argument does not contain an expression which should be broadcast.\nprotoFinished :: Unique -> Bool -> (forall s' t'. Unique -> Bool -> T s' t' -> G Bool) -> T s t -> G Bool\nprotoFinished u varyNoise rec0 =\n  let rec :: forall s t. 
T s t -> G Bool\n      rec = rec0 u varyNoise\n  in \\case\n    BroadcastT _ _ _ _s a -> rec a\n    MapT _ s f x -> rec x <&&> rec (f (T (Variable (Ref 0 s typeSTyp))))\n    ZipT _ s0 s1 f x y -> rec x <&&> rec y <&&> rec (f (T (Variable (Ref 0 s0 typeSTyp))) (T (Variable (Ref 0 s1 typeSTyp))))  \n    Zip3T _ s0 s1 s2 f x y z -> rec x <&&> rec y <&&> rec z <&&> rec (f (T (Variable (Ref 0 s0 typeSTyp))) (T (Variable (Ref 0 s1 typeSTyp))) (T (Variable (Ref 0 s2 typeSTyp))))  \n    Softmax _ _ x -> rec x\n    DirectBroadcast _ _ _ _ x -> rec x\n    GatherND _ _ _ x y -> rec x <&&> rec y\n    Noise _ _ _ _ -> return (not varyNoise)\n    Where cond x y -> rec cond <&&> rec x <&&> rec y\n    If cond x y ->  rec cond <&&> rec x <&&> rec y\n    T _ -> return True\n    Unbroadcast _p u' _x -> return (u /= u')\n    UnOp _op _ x -> rec x\n    MatMul _ _ _ _ x y -> rec x <&&> rec y\n    BinOp _op _ _ _ _ _ x y -> rec x <&&> rec y\n    Gather _is _s0 _m _s1 x ix -> rec x <&&> rec ix\n    Transpose _ _t x -> rec x\n    ReshapeFrom _s x -> rec x\n    Concat _s0  _s1 xs -> (and . htoList) <$> hTraverse (\\(Catable _ x) -> K <$> rec x) xs\n    Convolution _bs _inChans _outChans _filterShape _s x filters -> rec x <&&> rec filters\n    Pool _ _ _ _ _ x  -> rec x\n    -- _ -> error \"protoFinished: unhandled case\"\n\ndata K02 t x y = K02 {fromK02 :: t}\n\n\nmkFinished :: G (F2m G (Sig02 Bool (Sig02 Unique T)) (K02 Bool) ) -- forall s' t'. Unique -> Bool -> T s' t' -> G (F2m _)\nmkFinished = memo2 (ordMap @Bool `containing02` (ordMap @Unique `containing02` snMap2 @T)) $\n             \\rec (Ex02 u (Ex02 v x)) -> K02 <$> protoFinished v u (unwrapFin rec) x\n\nunwrapFin :: ((Sig02 Bool (Sig02 Unique T)) s t -> G (K02 Bool s t)) -> Unique -> Bool -> T s t -> G Bool\nunwrapFin f u v x = fromK02 <$> f (Ex02 v (Ex02 u x))\n\ndata KT s t where\n  KT ::  STyp t -> SShape s -> KT s t\n\ntype GenBCFn s t = (KnownTyp t, KnownShape s) => T s t -> G (T s t)\n\n\nunwrapGBCFn :: forall s t. 
(T s t -> KT s t -> G (T s t)) -> GenBCFn  s t\nunwrapGBCFn f x' = f x' (KT typeSTyp typeSShape)\n\n-- isBroadcastT :: T s t -> Bool\n-- isBroadcastT (BroadcastT {}) = True\n-- isBroadcastT _ = False\n\nmkGenerateBC :: (forall n s t. BroadcastFn n s t) -> G (F2m' G T KT T)\nmkGenerateBC broadcast = memo2' (snMap2 @T) $\n                         \\rec x (KT t s) -> knownTyp t $ do\n                           r <- generateBC' broadcast (\\sh' x' -> rec x' (KT typeSTyp sh')) s x\n                           -- when (isBroadcastT r) $ liftIO $ putStrLn \"YIKES!\"\n                           return r\n\nnewtype BC'd (n :: Nat) (s :: Shape) (t :: Typ) = BC'd {fromBC'd :: (T (n : s) t)}\n\ndata KTn n s t where\n  KTn ::  STyp t -> SShape s -> KTn n s t\n\ntype BroadcastFn n s t = forall proxy. (KnownNat n, KnownShape s, KnownTyp t) => Unique -> Bool -> proxy n -> T s t -> G (T (n : s) t)\n\nunwrapBCFn :: ((Sig03 Unique (Sig03 Bool (Sig12 (Sat KnownNat) T))) n s t -> KTn n s t -> G (BC'd n s t)) -> BroadcastFn n s t\nunwrapBCFn f u v _n x' = fromBC'd <$> f (Ex03 u (Ex03 v (Ex12 natSat x'))) (KTn typeSTyp typeSShape)\n\nmkBroadcastFn :: G (F3m' G (Sig03 Unique (Sig03 Bool (Sig12 (Sat KnownNat) T))) KTn BC'd)\nmkBroadcastFn = do\n  F2m fin <- mkFinished\n  memo3' (ordMap @Unique `containing03` (ordMap @Bool `containing03` (verifMap1 @(Sat KnownNat) `containing12` snMap2 @T))) $\n    \\rec (Ex03 u (Ex03 v (Ex12 n x))) (KTn st sh) ->\n      BC'd <$> protoBroadcast u v n\n                  (\\sh' x' -> fromBC'd <$> rec (Ex03 u (Ex03 v (Ex12 n x'))) (KTn typeSTyp sh'))\n                  (unwrapFin fin u v) st sh x\n\n\nclass ConsSh (x :: Nat) (p :: (Symbol,Shape,Typ))\ninstance Fun (ConsSh x) where type Ap (ConsSh x) p = '(Frst3 p,x ': Scnd3 p,Thrd3 p)\n\n-- -- | Turns a tensor of indices in a container into a tensor of indices\n-- -- in a container of higher rank. 
The added indexed dimension\n-- -- corresponds to the first dimension of the index.\n-- broadcastIndex :: forall n containerShape indexShape w.\n--   KnownBits w => Sat KnownNat n ->\n--   SShape containerShape ->\n--   SShape indexShape ->\n--   IndexTensor (n ': indexShape) containerShape w ->\n--   IndexTensor (n ': indexShape) (n ': containerShape) w\n-- broadcastIndex n cs = broadcastIndex' n (sListLenAsNat cs)\n\nbroadcastIndex' :: forall n containerRank indexShape w.\n  KnownBits w => Sat KnownNat n ->\n  Sat KnownNat containerRank ->\n  SShape indexShape ->\n  T (n ': indexShape ++ '[containerRank])  ('Typ 'Int w) ->\n  T (n ': indexShape ++ '[1 + containerRank]) ('Typ 'Int w)\nbroadcastIndex' n@(Comp Dict) cr is ix = concatT' ((:*) n is) (natSat @1) cr Unit nIndex ix\n  where nIndex :: T (n ': indexShape ++ '[1]) ('Typ 'Int w)\n        nIndex = DirectBroadcast Unit Unit ((:*) n Unit) (is .+. (:*) (natSat @1) Unit) range\n\n-- directBroadcast0 :: forall n s t. KnownShape s => KnownNat n => T s t -> T (n:s) t\n-- directBroadcast0 = appRUnit @s #> DirectBroadcast Unit ((:*) (natSat @n) Unit) (typeSShape @s) Unit\n\n-- broadcastIndexMany :: forall n containerShape indexShape w.\n--   KnownBits w =>\n--   Sat KnownNat n ->\n--   SShape containerShape ->\n--   SShape indexShape ->\n--   IndexTensor indexShape '[n] w ->\n--   IndexTensor (containerShape ++ indexShape) (containerShape ++ '[n]) w\n-- broadcastIndexMany _ Unit _ x = x\n-- broadcastIndexMany n ((:*) m@Sat cs) is x =\n--   knownSShape (cs .+. (*:) is (sListLenAsNat (cs *: n))) ?>\n--   -- (m : cs ++ is ++  '[(Length (m : cs ++ [n]))])\n--   (broadcastIndex m ((*:) cs n) (cs .+. 
is) $\n--   -- (m : (cs ++ is ++  '[Length (cs ++ [n])]))\n--   (appAssocS cs is ((:*) (sListLenAsNat (cs *: n)) Unit) #>\n--   -- (m : cs ++ is ++ '[Length (cs ++ [n])])\n--   directBroadcast0 $\n--   -- (cs ++ is ++  '[Length (cs ++ [n])])\n--   broadcastIndexMany n cs is x))\n--   -- is\n\n--  Product (filterSpatialShape ++ '[inChannels, outChannels * n])\n-- Product ((filterSpatialShape ++ '[inChannels, outChannels]) ++ '[n])\n\naxisOpInputShape :: Axis1Op s1 t s2 u -> SShape s1\naxisOpInputShape o = case o of\n  ArgMax n -> HSingle n\n  OneHot _n -> Unit\n  ReduceOp n _ -> HSingle n\n  ReverseT n -> HSingle n\n  SliceOp _ n _ _ -> HSingle n\n  AccessOp n _ -> HSingle n\n\nunopInputShape :: UnOp s t s' t' -> SShape s\nunopInputShape (Diag n) = n :* Unit\nunopInputShape Cast = Unit\nunopInputShape (Axis1Op s o) = axisOpInputShape o .+. s\nunopInputShape StopGradient = Unit\nunopInputShape (Num1Op _) = Unit\nunopInputShape (Float1Op _) = Unit\nunopInputShape (ExpM n) = n :* n :* Unit\nunopInputShape (ZeroTriangle n _ _) = n :* n :* Unit\nunopInputShape Conjugate = Unit\nunopInputShape RealPart = Unit\n\nprotoBroadcast :: forall n s t.\n  Unique -- unique identifier marking the variable tensor which will be marking inputs (not to broadcast).\n  -> Bool -- how to expand the noise? (If True use different noise for all indices)\n  -> Sat KnownNat n -- added dimension's size\n  -> (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> G (T (n ': s') t')) -- recursive case\n  -> (forall s' t'. 
T s' t' -> G Bool) -- test if we're done\n  -> STyp t -- representation of the type\n  -> SShape s -- representation of the shape\n  -> T s t -- tensor (expression) to broadcast\n  -> G (T (n ': s) t) -- return broadcast expression (on 1st position)\nprotoBroadcast u varyNoise n@(Comp Dict) rec finished ty s tensor = do\n  isFinished <- finished tensor\n  case isFinished of\n    True -> simpleBC\n    False -> knownTyp ty $ case tensor of\n      BroadcastT {} -> error \"BroadcastT case remaining, this should have been dealt with by generateBC\"\n      MapT {} -> error \"MapT case remaining, this should have been dealt with by generateBC\"\n      ZipT {} -> error \"ZipT case remaining, this should have been dealt with by generateBC\"\n      Zip3T {} -> error \"Zip3T case remaining, this should have been dealt with by generateBC\"\n      Softmax bs@Sat m@Sat x -> prodAssocS n bs m #> do\n        x' <- rec (typeSShape) x\n        return (reshapeAuto (Softmax (satMul n bs) m (reshapeAuto x')))\n      DirectBroadcast s0 s1 s2 s3 x -> do\n        x' <- (rec (s0 .+. s2) x)\n        return (DirectBroadcast (n :* s0) s1 s2 s3 x')\n      GatherND cs es is x ix -> do\n        xFinished <- finished x\n        case xFinished of\n          True -> GatherND cs es (n :* is) x <$> (rec (is *: sListLenAsNat cs) ix)\n          False -> do\n              ix' <- rec (is *: sListLenAsNat cs) ix\n              x' <- (rec (cs .+. 
es) x)\n              return (GatherND (n :* cs) es (n :* is) x' (broadcastIndex' n (sListLenAsNat cs) is ix'))\n      Noise v s0 s1 x -> if varyNoise then return (Noise v (n :* s0) s1 x) else simpleBC\n      -- When varying noise, then we extend the shape of the noise (so\n      -- more stuff is sampled), otherwise we copy the noise using simple\n      -- broadcasting\n      Pool bs@Sat window pt numChans outSpatial x ->\n        (knownSShape (zipWithMulSShapes window outSpatial *: numChans) ?>\n         (prodAssocS n bs (productS (zipWithMulSShapes window outSpatial *: numChans)) #>\n         (prodAssocS n bs (productS (outSpatial *: numChans)) #> do\n             x' <- (rec typeSShape x)\n             return $ (reshapeFrom (satMul n bs :* outSpatial *: numChans) $\n                       Pool (satMul n bs) window pt numChans outSpatial (reshapeAuto x')))))\n      Where cond x y -> Where <$> (rec s cond) <*> (rec s x) <*> (rec s y)\n      If cond x y -> do\n        condFinished <- finished cond\n        case condFinished of\n          True -> If cond <$> (rec s x) <*> (rec s y)\n          False ->  error \"broadcast on 'if' condition not implemented\"\n      T _ -> error \"panic: broadcast constant should be finished!\"\n      Unbroadcast p@Sat u' x\n        | u == u' -> return $ case testEq p n of\n            Nothing -> UnOp (error \"panic.unbroadcast.unit\") Unit x\n            Just Refl -> x\n        | otherwise -> knownSShape s ?> do\n            x' <- (rec (p :* s) x)\n            return (Unbroadcast p u' (transpose01 x'))\n          -- An incomplete broadcast (in another dimension).\n      MatMul s0 a@Sat b@Sat c@Sat x y -> do\n        yFinished <- finished y\n        case (s0,yFinished) of\n           (Unit,True) -> do\n             -- this optimisation is absolutely critical to implement dense\n             -- layers efficiently (at least with TF 1.3). 
(about 10x performance increase)\n             x' <- (rec (a :* b :* Unit) x)\n             return $ inflate2 (MatMul Unit (satMul n a) b c (flatten2 x') y)\n           _ -> MatMul (n :* s0) a b c  <$> (rec (s0 .+. a :* b :* Unit) x) <*> (rec (s0 .+. b :* c :* Unit) y)\n      BinOp op s0 s1 t1 s2 t2 x y -> knownTyp t1 $ knownTyp t2 $ do\n        BinOp op (n :* s0) s1 t1 s2 t2 <$> (rec (s0 .+. s1) x) <*> (rec (s0 .+. s2) y)\n      UnOp op s0 x -> UnOp op (n :* s0) <$> (rec (s0 .+. unopInputShape op) x)\n      Gather is s0 m s1 x ix -> do\n        xFinished <- finished x\n        case (s0,xFinished) of -- this optimisation is important to get efficient embeddings (???)\n          (Unit,True) -> Gather (n :* is) Unit m s1 x <$> (rec is ix)\n          _ -> Gather is (n :* s0) m s1 <$> (rec (s0 .+. m :* s1) x) <*> (rec (s0 .+. is) ix)\n      Transpose s0 t x -> Transpose (n :* s0) (PermSkip t) <$> (rec s0 x)\n      ReshapeFrom s0 x -> reshapeFrom (n :* s0) <$> (rec s0 x)\n      Concat s0 s1 xs -> do\n        Concat (n :* s0) s1 <$> hTraverse (\\(Catable m x) -> Catable m <$> (rec (s0 .+. m :* s1) x)) xs\n      Convolution bs@(Sat) inChans outChans filterShape s0 x filters -> do\n        filtersFinished <- finished filters\n        xFinished <- finished x\n        case (filtersFinished,xFinished) of\n          (True,_) ->\n            prodAssocS n bs (productS (s0 *: inChans))  #>\n            prodAssocS n bs (productS (s0 *: outChans)) #>\n            knownSShape (s0 *: inChans)                 ?> do\n              x' <- (rec typeSShape x)\n              return $ reshapeFrom (satMul n bs :* s0 *: outChans) \n                      (Convolution (satMul n bs) inChans outChans filterShape s0 (reshapeAuto x') filters)\n          (_,True) ->\n            knownSShape (filterShape .+. inChans :* outChans :* Unit) ?>\n            knownSShape (bs :* s0 .+. 
outChans :* Unit) ?> do\n              filters' <- rec typeSShape filters\n              return $ transposeN' $\n                reshapeProven (ANat bs !:* AShape s0 *:! (ANat outChans :*: ANat n))\n                              ((ANat bs !:* AShape s0 *:! ANat outChans) *:! ANat n) $\n                Convolution bs inChans (outChans `satMul` n) filterShape s0 x $\n                reshapeProven ((AShape filterShape :++: (ANat inChans !:* Single (ANat outChans))) *:! ANat n)\n                              (AShape filterShape :++: ANat inChans !:* Single (ANat outChans :*: ANat n)) $\n                transposeN $ filters'\n\n          _ -> error \"broadcast on both convolution filter and data not implemented\"\n      _ -> error \"protoBroadcast: unhandled case\" \n  where simpleBC :: G (T (n ': s) t)\n        simpleBC = appRUnit @s #>\n                   return (DirectBroadcast Unit (n :* Unit) s Unit tensor)\n\ninversePerm :: Permutation a b -> Permutation b a\ninversePerm PermId = PermId\ninversePerm (PermSkip x) = PermSkip (inversePerm x)\ninversePerm PermSwap = PermSwap\ninversePerm (PermTrans x y) = PermTrans (inversePerm y) (inversePerm x)\n\npermToFun :: Permutation s t -> Integer -> Integer\npermToFun = \\case\n  PermId -> \\x -> x\n  PermTrans a b -> permToFun b . permToFun a\n  PermSwap -> \\case\n    0 -> 1\n    1 -> 0\n    x -> x\n  PermSkip p -> \\case\n    0 -> 0\n    x -> permToFun p (x-1) + 1\n\nreshapeAuto :: forall s s0 t. KnownShape s0 => Product s ~ Product s0 => T s0 t -> T s t\nreshapeAuto = reshapeFrom typeSShape\n\nreshapeProven :: forall s s0 t n. ShapeX s0 n -> ShapeX s n -> T s0 t -> T s t\nreshapeProven s1 s2 = case decideProductEq s1 s2 of\n                        Refl -> knownSShape (exprSShape s1) ?> reshapeAuto\n\nreshapeTo :: forall s s0 t proxy. KnownShape s0=> Product s ~ Product s0 => proxy s -> T s0 t -> T s t\nreshapeTo _ = reshapeAuto\n\nreshapeFrom :: forall s s0 t. 
Product s ~ Product s0 => SShape s0 -> T s0 t -> T s t\nreshapeFrom _ (ReshapeFrom s1 x) = ReshapeFrom s1 x -- avoid reshaping over and over\nreshapeFrom s0 x = ReshapeFrom s0 x\n\n\ntype BatchedPlaceholders n ps = Placeholders (BPH n ps)\ntype BPH n ps = (Ap (FMap (ConsSh n)) ps)\n\n-- | Batch the model (adding one dimension).\nmapPlaceHolders :: forall batchSize shapesAndTypes resShapesAndTypes.\n    (KnownNat batchSize, KnownLen shapesAndTypes, KnownLen resShapesAndTypes, All KnownPlaceholder shapesAndTypes, All KnownPlaceholder resShapesAndTypes)\n  => Unique\n  -> Bool\n  -> (Placeholders shapesAndTypes -> Placeholders resShapesAndTypes)\n  -> BatchedPlaceholders batchSize shapesAndTypes -> (BatchedPlaceholders batchSize resShapesAndTypes)\nmapPlaceHolders u varyNoise f xs = broadcastPlacehoders @batchSize typeSList (f (unbroadcastPlacehoders @batchSize typeSList xs))  where\n    unbroadcastPlacehoders :: forall n r. KnownNat n => SList r -> BatchedPlaceholders n r -> Placeholders r\n    unbroadcastPlacehoders Unit Unit = Unit\n    unbroadcastPlacehoders (_ :* ss) (PHT x :* xs') = PHT (Unbroadcast batchSize u x) :* unbroadcastPlacehoders @n ss xs'\n      where batchSize = natSat @n\n\n    broadcastPlacehoders :: forall n r. 
All KnownPlaceholder r => KnownNat n => SList r -> Placeholders r -> (BatchedPlaceholders n r)\n    broadcastPlacehoders Unit Unit = Unit\n    broadcastPlacehoders (_ :* ss) (PHT x :* xs) =\n      let x' = BroadcastT (Just u) varyNoise (natSat @n) typeSShape x\n          xs' = broadcastPlacehoders @n ss xs\n      in (PHT x' :* xs') \n\n----------------------------------------------------------------\n-- Here start helper functions\n\npermN :: SList s -> Permutation (n ': s) (s ++ '[n])\npermN Unit = PermId\npermN ((:*) _n s) = PermSwap `PermTrans` PermSkip (permN s)\n\npermN01 :: SList s -> Proxy m -> Proxy n -> Permutation (s ++ [m,n]) (s ++ [n,m])\npermN01 Unit _ _ = PermSwap\npermN01 ((:*) _n s) m n = PermSkip (permN01 s m n)\n\n\n-- | Transposition. See the type for the permutation of dimensions.\ntransposeN :: ∀ s n t. KnownNat n => KnownShape s => T (n ': s) t -> T (s ++ '[n]) t\ntransposeN  = doTranspose typeSShape (permN (typeSList @s))\n\n-- | Transposition. See the type for the permutation of dimensions.\ntransposeN' :: ∀ s n t. KnownNat n => KnownShape s => T (s ++ '[n]) t -> T (n ': s) t\ntransposeN' = doTranspose (typeSShape @s *: (natSat @n)) (inversePerm (permN (typeSList @s)))\n\n\n-- | Transposition. See the type for the permutation of dimensions.\ntransposeN01 :: ∀ s m n t. KnownNat n => KnownNat m => KnownShape s => T (s ++ [m,n]) t -> T (s ++ [n,m]) t\ntransposeN01 = doTranspose (typeSShape @s .+. typeSShape @'[m,n]) (permN01 (typeSList @s) (Proxy @m) (Proxy @n))\n\n\n-- | Transposition. See the type for the permutation of dimensions.\ntranspose01 :: ∀ s m n t. KnownNat n => KnownNat m => KnownShape s => T (m ': n ': s) t -> T (n ': m ': s) t\ntranspose01 = doTranspose typeSShape PermSwap\n\ndoTranspose :: SShape s0 -> Permutation s0 s -> T s0 t -> T s t\ndoTranspose _  p (Transpose sh' q x) = doTranspose sh' (PermTrans q p) x\ndoTranspose sh p x = Transpose sh p x\n\n\n-- | Concatenate tensors with explicit shapes.\nconcatT' :: ∀ s0 d1 d2 s1 t. 
KnownTyp t =>\n    SShape s0 -> Sat KnownNat d1 -> Sat KnownNat d2 -> SShape s1 -> T (s0 ++ (d1 ': s1)) t -> T (s0 ++ (d2 ': s1)) t -> T (s0 ++ ((d1+d2) ': s1)) t\nconcatT' s0 d1@(Comp Dict) d2@(Comp Dict) s1 x y = Concat s0 s1 (Catable d1 x :* Catable d2 y :* Unit)\n\n\n-- | Reshape a tensor so that the first dimension is expanded into two.\ninflate2 :: ∀ m n s t. KnownTyp t => (KnownNat m, KnownNat n, KnownShape s) => Tensor (m*n ': s) t -> Tensor (m ': n ': s) t\ninflate2 = prodAssoc @m @n @(Product s) #> reshape\n\n\n-- | Reshape a tensor so that the first two dimensions are collapsed\nflatten2 :: ∀ m n s t. KnownTyp t => (KnownNat m, KnownNat n, KnownShape s) => Tensor (m ': n ': s) t -> Tensor (m*n ': s) t\nflatten2 = prodAssoc @m @n @(Product s) #> reshape\n\n\nreshape :: ∀ s2 s1 t. KnownShape s1 => KnownShape s2 => Product s1 ~ Product s2 => Tensor s1 t -> Tensor s2 t\nreshape = reshapeAuto\n\n-- | range[i] = i\nrange :: forall n w. KnownNat n => KnownBits w => T '[n] ('Typ 'Int w)\nrange = T (Range (natSat @n))\n"
  },
  {
    "path": "TypedFlow/Haskell.hs",
    "content": "{-|\nModule      : TypedFlow.Haskell\nDescription : Generation of computation graph using tensorflow haskell. \nCopyright   : (c) Jean-Philippe Bernardy, 2017\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n\n-}\n\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE GeneralizedNewtypeDeriving #-}\n{-# LANGUAGE LambdaCase #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE MultiParamTypeClasses #-}\n{-# LANGUAGE OverloadedStrings #-}\n{-# LANGUAGE PatternSynonyms #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE RecordWildCards #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeInType #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UndecidableSuperClasses #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n\nmodule TypedFlow.Haskell where\n\nimport Data.Type.Equality\nimport Data.List (genericReplicate)\nimport GHC.TypeLits\nimport Control.Monad.State\nimport TypedFlow.Types\nimport TypedFlow.Types.Proofs\nimport TypedFlow.Abstract (newId, permToFun, unopInputShape)\nimport TypedFlow.Memo\nimport System.Mem.StableName\nimport System.IO.Unsafe\n\nimport qualified Data.Int as Backend\n\nimport qualified TensorFlow.Core        as Backend\nimport qualified TensorFlow.GenOps.Core as BackCore\nimport qualified TensorFlow.Minimize    as Backend\nimport qualified TensorFlow.Ops         as Backend\nimport qualified TensorFlow.NN          as Backend\n-- import qualified TensorFlow.Variable    as Backend\nimport qualified TensorFlow.Tensor\n\nimport qualified Data.IntMap as 
IM\nimport Data.IntMap (IntMap)\n\ntype BackendShape = BackendTensor ('Typ 'Int 'B32)\ntype BackendTensor t = Backend.Tensor Backend.Build (HaskType t)\ntype BackendVariable t = Backend.Tensor Backend.Ref (HaskType t)\ntype BackendTensorType t = Backend.TensorType (HaskType t)\n\nshapeFromType :: ∀ (s :: Shape). KnownShape s => BackendShape\nshapeFromType = shapeVector (typeSShape @s)\n\n-- | Show a shape, but \"None\" is replaced by \"-1\"\nshapeVector :: forall (s::Shape) proxy. All KnownNat s => SList' proxy s -> BackendShape\nshapeVector s = shapeFromList (shapeToList'' s)\n\npermToTensor :: SShape s -> Permutation s t -> Backend.Tensor Backend.Build Backend.Int32\npermToTensor s p = Backend.vector (map (fromInteger . permToFun p) [0.. sListLength s])\n\nshapeFromList :: [Integer] -> BackendShape\nshapeFromList = Backend.vector . map convertNone\n\nshowShapeLen :: ∀ (s::Shape). KnownLen s => Backend.Int32\nshowShapeLen = fromIntegral (listTypeLen @ s)\n\nconvertNone :: Num a => Integer -> a\nconvertNone n = (if n == 514229 then (-1) else fromIntegral n)\n\n-- runWithFeeds\n\ndata BT (s :: Shape) (t :: Typ) where\n  BT :: forall s t. (BackendTensor t) -> BT s t\n\ndata HState = HState {genVars :: IntMap Var\n                     ,genPureTable :: SNMap22 Shape Typ T BT\n                     -- alternative: use tensorRefFromName and make this closer to the python backend.\n                     }\n\ntype BM a = Backend.BuildT (StateT HState (State GState)) a\n\ndata Var = forall s t v. TensorFlow.Tensor.TensorKind v => Var (SShape s) (STyp t) (Backend.Tensor v (HaskType t))\n\ninitializedVariable :: forall s a. 
KnownShape s => KnownTyp a => T s a -> BM (Ref s a)\ninitializedVariable initVal = do\n  BT i <- interpretPure initVal\n  x <- lift (lift newId)\n  v <- backendTensor (typeSTyp @a) $ Backend.initializedVariable i\n  let var = (Var (typeSShape @s) (typeSTyp @a) v)\n  lift (modify $ \\HState{..} -> HState {genVars = IM.insert (fromIntegral x) var genVars,..})\n  return (Ref (fromIntegral x) typeSShape typeSTyp )\n\nplaceholder :: forall s a. SShape s -> STyp a -> BM (Ref s a)\nplaceholder s t = do\n  x <- lift (lift newId)\n  ph <- backendTensor t $ Backend.placeholder (Backend.Shape (map convertNone $ shapeToList' s))\n  let var = (Var s t ph)\n  lift (modify $ \\HState{..} -> HState {genVars = IM.insert (fromIntegral x) var genVars,..})\n  return (Ref (fromIntegral x) s t )\n\ninterpGen :: Gen a -> BM a\ninterpGen (GPReturn x) = return x\ninterpGen (GPVariable _trainable _name initVal) = initializedVariable initVal\ninterpGen (GPPlaceholder s t _name) = placeholder s t\ninterpGen (GPModify _ _) = error \"GPModify: TODO\"\ninterpGen (GPState f) = lift (lift (state f))\ninterpGen (GPBind a b) = do x <- interpGen a\n                            interpGen (b x)\n\nlistProxyLen :: forall proxy s. KnownLen s => proxy s -> Integer\nlistProxyLen _ = listTypeLen @s\n\n-- genDistr :: forall s s0 t. KnownTyp t => Distribution s t -> SShape s0 -> SShape s -> DOC\n-- genDistr d sh s1 = case d of\n--   TruncatedNormalD stddev -> funcall \"tf.truncated_normal\"\n--     [showSShape (sh .+. s1), named \"stddev\" (float stddev), named \"dtype\" (showTyp @t)]\n--   UniformD low high -> funcall \"tf.random_uniform\" [showSShape (sh .+. s1)\n--                                 ,named \"minval\" (float low)\n--                                 ,named \"maxval\" (float high)\n--                                 ,named \"dtype\" (showTyp @t)]\n--   OrthogonalD ->\n--     funcall' (funcall \"tf.orthogonal_initializer\" [named \"dtype\" (showTyp @t)]) [named \"shape\" (showSShape (sh .+. 
s1))]\n\n\nknownNumeric :: forall t k. KnownNumeric t => (KnownTyp t => Num (HaskType t) => Backend.OneOf '[Backend.Int32, Float, Double] (HaskType t) => k) -> k\nknownNumeric = knownNumeric' (typeSTyp @t)\n\nknownNumeric' :: forall t k. KnownNumeric t => STyp t -> (KnownTyp t => Num (HaskType t) => Backend.OneOf '[Backend.Int32, Float, Double] (HaskType t) => k) -> k\nknownNumeric' (STyp tk tb Refl) k = case tk of\n  SFloat -> case tb of\n    SB32 -> k\n    SB64 -> k\n  SBool -> error \"TFNumeric bug\"\n  SInt -> case tb of\n    SB32 -> k\n    SB64 -> error \"missing in tensorflow: int64 is not supported in matmul T_T\"\n\nknownFloatingB :: forall t k. (KnownTyp t, TypKind t ~ 'Float) => (Backend.OneOf '[Float, Double] (HaskType t) => k) -> k\nknownFloatingB k = case bitsVal @(TypBits t) of\n    SB32 -> k\n    SB64 -> k\n\nknownInt :: forall t k. (KnownTyp t, TypKind t ~ 'Int) => (Backend.OneOf '[Backend.Int32, Backend.Int64] (HaskType t) => k) -> k\nknownInt k = case bitsVal @(TypBits t) of\n    SB32 -> k\n    SB64 -> k\n\nbackendTensor :: STyp t ->  (Backend.TensorType (HaskType t) => k) -> k\nbackendTensor (STyp SFloat SB32 Refl) k = k\nbackendTensor (STyp SInt SB64 Refl) k = k\nbackendTensor (STyp SBool _ Refl) k = k\nbackendTensor (STyp SFloat SB64 Refl) k = k\nbackendTensor (STyp SInt SB32 Refl) k = k\n\nbackendTensor' :: forall t k proxy. KnownTyp t => proxy t -> (Backend.TensorType (HaskType t) => k) -> k\nbackendTensor' _ = backendTensor (typeSTyp @t)\n\n\nrunUnOp :: forall s s1 t s2 u. 
KnownTyp u => KnownTyp t => BackendTensorType u => SShape s -> UnOp s1 t s2 u -> BT (s++s1) t -> BT (s++s2) u\nrunUnOp sL op (BT x) = backendTensor (typeSTyp @t) $ case op of\n  SliceOp _ sR lo hi -> BT $ BackCore.slice x\n    (shapeFromList (replicate (sListLen  sL) 0 ++ [lo] ++ replicate (sListLen sR) 0))\n    (shapeFromList (shapeToList' sL ++ [hi-lo] ++ (shapeToList' sR)))\n  Axis1Op aop -> case aop of\n    (ArgMax _ _) -> knownNumeric @t $ knownInt @u $ BT $ BackCore.argMax x (Backend.scalar sLLen)\n    (OneHot _) -> knownNumeric @u $ knownInt @t $  BT $ Backend.oneHot x (Backend.scalar sLLen) (Backend.scalar 1) (Backend.scalar 0)\n    ReduceOp _ _sR rop -> knownNumeric @t $ case rop of\n      Max -> BT $ BackCore.max x redindices\n      Min -> BT $ BackCore.min x redindices\n      Sum -> BT $ Backend.sum x redindices\n      Mean -> BT $ Backend.mean x redindices\n     where redindices = (Backend.vector [fromIntegral (sListLen sL) :: Backend.Int32 ])\n  StopGradient -> BT $ BackCore.stopGradient x\n  Cast -> BT $ Backend.cast x\n  (Num1Op numop) -> knownNumeric @t $ case numop of\n    Square -> BT (Backend.mul x x)\n    Negate -> BT (Backend.neg x)\n    Sign -> BT (Backend.sign x)\n    Abs -> BT (Backend.abs x)\n    FloorMod -> BT (Backend.floorMod x)\n  Float1Op flop -> knownFloatingB @t $ knownFloating @(TypBits u) $ knownFloatingB @u $ case flop of\n     Tanh -> BT (BackCore.tanh x)\n     Sin -> BT (BackCore.sin x)\n     Exp -> BT (BackCore.exp x)\n     Sigmoid -> BT (BackCore.sigmoid x)\n     Relu -> BT (BackCore.relu x)\n     Floor -> BT (BackCore.floor x)\n     Round -> BT (BackCore.round x)\n     Cos -> BT (BackCore.cos x)\n     Log -> BT (BackCore.log x)\n     Asin -> BT (BackCore.asin x)\n     Acos -> BT (BackCore.acos x)\n     Sinh -> BT (BackCore.sinh x)\n     Cosh -> BT (BackCore.cosh x)\n     Asinh -> BT (BackCore.asinh x)\n     Acosh -> BT (BackCore.acosh x)\n     Atan -> BT (BackCore.atan x)\n     Atanh -> BT (BackCore.atanh x)\n     Sqrt -> BT 
(BackCore.sqrt x)\n     HardSigmoid -> error \"Haskell: no hard sigmoid defined yet\"\n     ClipByValue lo hi -> BT $ BackCore.clipByValue x (Backend.scalar $ realToFrac lo) (Backend.scalar $ realToFrac hi)\n  Diag _ -> BT $ BackCore.batchMatrixDiag x\n where sLLen = fromIntegral (sListLen sL) :: Backend.Int32\n\ninterpretPure :: forall s t. KnownTyp t => KnownShape s => T s t -> BM (BT s t)\ninterpretPure x = do\n  let sn = unsafePerformIO $ makeStableName x\n  mv <- snMap22Lookup sn <$> lift (gets genPureTable)\n  case mv of\n    Just v -> return v\n    Nothing -> do\n      e  <- interpretPure' (\\s x' -> knownSShape s $ interpretPure x') typeSShape x\n      lift $ modify (\\g -> g {genPureTable = (snMap22Insert (KV sn e)) (genPureTable g)})\n      return e\n\ninterpNilOp :: forall s t. Backend.TensorType (HaskType t) => NilOp s t -> BM (BT s t)\ninterpNilOp = \\case\n  Constant c -> return $ BT $ Backend.scalar c\n  Range n@Sat -> knownNumeric @t $ return $\n    let start,limit,delta :: HaskType t\n        start = 0\n        limit = fromIntegral $ natVal n\n        delta = 1\n    in BT $ Backend.range (Backend.scalar start) (Backend.scalar limit) (Backend.scalar delta)\n  Variable (Ref r sr tr) -> do\n     tbl <- lift (gets genVars)\n     case IM.lookup r tbl of\n       Just (Var sx tx x) -> case (testEq sx sr, testEq tx tr) of\n          (Just Refl, Just Refl) -> return (BT (Backend.expr x))\n          _ -> error \"panic: variable does not have the expected type\"\n       _ -> error \"panic: variable not found\" \n\ninterpretPure' :: forall s t. KnownTyp t => (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> BM (BT s' t')) -> SShape s -> T s t -> BM (BT s t)\ninterpretPure' rec sR = knownSShape sR $ backendTensor (typeSTyp @t) $ \\case\n  Unbroadcast{} -> error \"broadcasting operation did not complete!\"\n  DirectBroadcast s0 s1 s2 s3 x -> do\n    BT recx <- rec (s0 .+. 
s2) x\n    let expandedShape = shapeFromList\n                          (concat [shapeToList' s0, genericReplicate (sListLength s1) 1\n                                  ,shapeToList' s2, genericReplicate (sListLength s3) 1 ])\n        targetShape = shapeFromList sR\n    return $ BT $ BackCore.broadcastTo (Backend.reshape recx expandedShape) targetShape\n   --  Noise noiseId s0 s1 x -> do\n   --    return $ (genDistr x s0 s1) <+> (text \"# \" <> integer noiseId)\n  T op -> interpNilOp op\n  Where c x y -> do\n    BT rc <- rec typeSShape c\n    BT rx <- rec typeSShape x\n    BT ry <- rec typeSShape y\n    return $ BT $ BackCore.select rc rx ry\n  UnOp operation s0 x -> do\n    recx <- rec (s0 .+. unopInputShape operation) x\n    return (runUnOp s0 operation recx)\n  MatMul s0 a b c x y  -> do\n    BT recx <- rec (s0 .+. a :* b :* Unit) x\n    BT recy <- rec (s0 .+. b :* c :* Unit) y\n    return $ knownNumeric @t $ BT $ BackCore.batchMatMul recx recy\n  BinOp operation s0 s1 t s2 u x y -> knownSShape s0 $ knownSShape s1 $ knownSShape s2 $ knownProduct' s0 $ do\n   BT recx <- rec (s0 .+. s1) x\n   BT recy <- rec (s0 .+. 
s2) y\n   let reshx = backendTensor t $ Backend.reshape recx (shapeVector (satProd s0 :* s1))\n       reshy = backendTensor u $ Backend.reshape recy (shapeVector (satProd s0 :* s2))\n   return $ case operation of\n     Simple2Op sop  -> case sop of\n        Add -> knownNumeric @t $ BT $ Backend.add recx recy\n        Divide -> knownNumeric @t $ BT $ BackCore.div recx recy\n        Equal -> backendTensor u $ BT $ Backend.equal recx recy\n        Subtract -> knownNumeric @t $ BT $ Backend.sub recx recy\n        Multiply -> knownNumeric @t $ BT $ Backend.mul recx recy\n        Minimum -> knownNumeric @t $ BT $ BackCore.minimum recx recy\n        Maximum -> knownNumeric @t $ BT $ BackCore.maximum recx recy\n        LessThan ->  knownNumeric' u $ BT $ BackCore.less recx recy\n     -- WTF moment: the arguments do not seem to be in the same order in python as in haskell\n     -- python: https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits\n     -- haskell: https://tensorflow.github.io/haskell/haddock/tensorflow-core-ops-0.2.0.0/TensorFlow-GenOps-Core.html#v:sparseSoftmaxCrossEntropyWithLogits\n     SparseSoftmaxCrossEntropyWithLogits -> case t of\n        STyp SInt SB32 Refl -> knownFloatingB @t $ BT $ fst $ BackCore.sparseSoftmaxCrossEntropyWithLogits reshy reshx\n     SoftmaxCrossEntropyWithLogits -> knownFloatingB @t $ BT $ fst $ BackCore.softmaxCrossEntropyWithLogits reshy reshx\n     -- SigmoidCrossEntropyWithLogits -> knownFloatingB @t $ BT $ Backend.sigmoidCrossEntropyWithLogits recy recx -- type is not as general as necessary\n  ReshapeFrom s t -> do\n    BT rt <- rec s t\n    return $ BT $ BackCore.reshape rt (shapeVector sR)\n  Concat s0 s1 xs -> do\n    let go :: forall s0 s1 ns. SShape s0 -> SShape s1 -> NP (Catable s0 s1 t) ns -> BM [BackendTensor t]\n        go _ _ Unit = return []\n        go s0' s1' (Catable n y :* ys) = do\n          BT y' <- rec (s0' .+. 
n :* s1') y\n          (y' :) <$> go s0' s1' ys\n    rxs <- go s0 s1 xs\n    return $ BT $ Backend.concat (Backend.scalar (fromIntegral (sListLength s0))) rxs\n  Transpose s p x -> do\n    BT rx <- rec s x\n    return $ BT $ Backend.transpose rx (permToTensor s p)\n --  Gather indexShape s0 m s1 x ix -> do\n --    rx <- rec (s0 .+. ((:*) m s1)) x\n --    rix <- rec indexShape ix\n --    return (func \"tf.gather\" [rx, rix] [])\n --  GatherND containerShape elementShape indexShape x ix -> do\n --    rx <- rec (containerShape .+. elementShape) x\n --    rix <- rec (indexShape *: (sListLenAsNat containerShape)) ix\n --    return (func \"tf.gather_nd\" [rx, rix] [])\n  Convolution bs inChans outChans filterShape s0 x filters -> do\n    BT recx <- rec (bs :* (s0 *: inChans)) x\n    BT recFilters <- rec (filterShape .+. inChans :* outChans :* Unit) filters\n    case filterShape of\n       _width :* _height :* Unit ->\n          return $ BT $ knownFloatingB @t $ BackCore.conv2D recx recFilters\n       _ -> error \"TypedFlow Haskell backend: convolution on an unsupported number of dims\"\n --  Pool bs window typ numChans outSpatial x -> do\n --     rx <- rec ((:*) bs (zipWithMulSShapes window outSpatial .+. (:*) numChans Unit)) x\n --     return (func \"tf.nn.pool\"\n --                  [rx, showSShape window, typ', text (show (\"SAME\" :: String))]\n --                  [(\"strides\", showSShape window)])\n --   where typ' = text $ (show $ case typ of MaxPool -> \"MAX\"; AvgPool -> \"AVG\" :: String)\n -- -- where rec :: forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> DOC\n -- --       rec = generatePure' \n\n"
  },
  {
    "path": "TypedFlow/Layers/Core.hs",
    "content": "{-|\nModule      : TypedFlow.Layers.Core\nDescription : Core layers and combinators.\nCopyright   : (c) Jean-Philippe Bernardy, 2017\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n-}\n{-# LANGUAGE CPP #-}\n#if __GLASGOW_HASKELL__ >= 806\n{-# LANGUAGE NoStarIsType #-}\n#endif\n{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE MultiParamTypeClasses #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE ViewPatterns #-}\n{-# LANGUAGE TypeInType #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-# LANGUAGE PatternSynonyms #-}\n\nmodule TypedFlow.Layers.Core\n  (\n    -- * Dense\n    DenseP(..), dense, (#),\n    -- * Dropout\n    DropProb(..), mkMask, mkDropout, mkDropouts,\n    -- * Embedding\n    EmbeddingP(..), embedding, \n    -- * Convolutional\n    ConvP(..), conv, conv', {-convValid,-} maxPool1D, maxPool2D,\n    glu\n  )\n\nwhere\nimport Prelude hiding (RealFrac(..))\nimport GHC.TypeLits\nimport TypedFlow.TF\nimport TypedFlow.Types\nimport TypedFlow.Types.Proofs\nimport TypedFlow.Abstract\nimport Control.Monad.State (gets)\nimport Data.Monoid ((<>))\n---------------------\n-- Linear functions\n\n\n-- | A dense layer is a linear function form a to b: a transformation matrix and a bias.\ndata DenseP t a b = DenseP {denseWeights :: Tensor '[a,b] t\n                           ,denseBiases  :: Tensor '[b] t}\n\n-----------------------\n-- 
Feed-forward layers\n\n-- | Parameters for the embedding layers\nnewtype EmbeddingP numObjects embeddingSize t = EmbeddingP (Tensor '[numObjects, embeddingSize] t)\n\ninstance (KnownNat numObjects, KnownTyp b, KnownNat embeddingSize) => KnownTensors (EmbeddingP numObjects embeddingSize b) where\n  travTensor f s (EmbeddingP p) = EmbeddingP <$> travTensor f s p\n\ninstance (KnownNat numObjects, KnownBits b, KnownNat embeddingSize) => ParamWithDefault (EmbeddingP numObjects embeddingSize ('Typ 'Float b)) where\n  defaultInitializer = EmbeddingP <$> (noise $ UniformD (-0.05) 0.05)\n\ninstance (KnownNat numObjects, KnownBits b, KnownNat embeddingSize) => ParamWithDefault (EmbeddingP numObjects embeddingSize ('Typ 'Cmplx b)) where\n  defaultInitializer = EmbeddingP <$> (mkComplex <$> (noise $ UniformD (-0.05) 0.05) <*> (noise $ UniformD (-0.05) 0.05))\n\n-- | embedding layer\nembedding :: ∀ embeddingSize numObjects t. KnownNat embeddingSize => KnownNat numObjects =>\n             EmbeddingP numObjects embeddingSize t -> Tensor '[] Int32 -> Tensor '[embeddingSize] t\nembedding (EmbeddingP param) input = gather param input\n\n\n\ninstance (KnownNat a, KnownNat b, KnownTyp t) => KnownTensors (DenseP t a b) where\n  travTensor f s (DenseP x y) = DenseP <$> travTensor f (s<>\"_w\") x <*> travTensor f (s<>\"_bias\") y\n\ninstance (KnownNat n, KnownNat m, KnownFloat b) => ParamWithDefault (DenseP b n m) where\n  defaultInitializer = DenseP <$> glorotUniform <*> (noise $ TruncatedNormalD 0.1)\n\n-- | Dense layer (Apply a linear function)\n(#), dense :: ∀m n t. KnownNat n => KnownNat m => KnownNumeric t => DenseP t n m -> Tensor '[n] t -> Tensor '[m] t\n(DenseP weightMatrix bias) # v = (weightMatrix ∙ v) + bias\n\ndense = (#)\n\n-- | A drop probability. (This type is used to make sure one does not\n-- confuse keep probability and drop probability)\ndata DropProb = DropProb Float\n\n-- | Generate a dropout function. 
The mask applied by the returned\n-- function will be constant for any given call to mkDropout.  See\n-- 'noise' for the sampling behaviour.\nmkDropout :: forall s t. KnownShape s => KnownFloat t => DropProb -> Gen (Tensor s t -> Tensor s t)\nmkDropout d = (⊙) <$> mkMask d\n\n-- | Generate a 0-1 mask with given probability, suitable for dropout,\n-- or all ones if not in training phase. See 'noise' for the sampling\n-- behaviour.\nmkMask :: forall s t. KnownShape s => KnownFloat t => DropProb -> Gen (Tensor s t)\nmkMask (DropProb dropProb) = do\n  let keepProb = 1 - dropProb\n  let isTraining = genTrainingPlaceholder\n  r <- noise $ UniformD keepProb (1 + keepProb)\n  return $ if_ isTraining\n               (floor r ⊘ constant (knownAlgebraic @t $ realToFrac keepProb))\n               ones\n\nnewtype EndoTensor t s = EndoTensor (Tensor s t -> Tensor s t)\n\n-- | Generate a dropout function for a heterogeneous tensor vector.\nmkDropouts :: KnownFloat t => KnownLen shapes => All KnownShape shapes => DropProb -> Gen (HTV t shapes -> HTV t shapes)\nmkDropouts d = appEndoTensor <$> mkDropouts' typeSList where\n   mkDropouts' :: forall shapes t. 
KnownFloat t => All KnownShape shapes =>\n                  SList shapes -> Gen (NP (EndoTensor t) shapes)\n   mkDropouts' Unit = return Unit\n   mkDropouts' (_ :* rest) = do\n     x <- mkDropout d\n     xs <- mkDropouts' rest\n     return (EndoTensor x :* xs)\n\n   appEndoTensor :: NP (EndoTensor t) s -> HTV t s -> HTV t s\n   appEndoTensor Unit Unit = Unit\n   appEndoTensor (EndoTensor f :* fs) (F x :* xs) = F (f x) :* appEndoTensor fs xs\n\n\n------------------------\n-- Convolutional layers\n\ndata ConvP t outChannels inChannels filterSpatialShape\n  = ConvP (T (filterSpatialShape ++ '[inChannels,outChannels]) t)\n          (T '[outChannels] t)\n\ninstance (KnownNat outChannels,KnownNat inChannels, KnownShape filterSpatialShape, KnownFloat t) =>\n  ParamWithDefault (ConvP t outChannels inChannels filterSpatialShape) where\n  defaultInitializer = prodHomo @filterSpatialShape @'[inChannels, outChannels] #>\n                       prodAssoc @(Product filterSpatialShape) @inChannels @outChannels #>\n                       knownAppend @filterSpatialShape @'[inChannels,outChannels] ?>\n                       knownProduct @filterSpatialShape ?>\n                       ConvP <$> (reshape <$> i) <*> pure (knownAlgebraic @t (constant 0.1))\n    where i :: Gen (T '[Product filterSpatialShape*inChannels,outChannels] t)\n          i = knownProduct @filterSpatialShape ?> glorotUniform\n\ninstance (KnownNat outChannels,KnownNat inChannels, KnownShape filterSpatialShape, KnownAlgebraic t) =>\n  KnownTensors (ConvP t outChannels inChannels filterSpatialShape) where\n  travTensor f s (ConvP x y) = knownAppend @filterSpatialShape @'[inChannels,outChannels] ?>\n          (ConvP <$> travTensor f (s<>\"_filters\") x <*> travTensor f (s <> \"_biases\") y)\n\n-- | Size-preserving convolution layer\nconv' :: forall s outChannels filterSpatialShape inChannels t.\n               KnownShape s => KnownNat inChannels => KnownNat outChannels => KnownShape filterSpatialShape => KnownAlgebraic 
t\n            => Length filterSpatialShape <= 3\n            => Length filterSpatialShape ~ Length s\n            => ConvP t outChannels inChannels filterSpatialShape\n            -> T (s ++ '[inChannels]) t\n            -> T (s ++ '[outChannels]) t\nconv' (ConvP filters bias) input = mapTT @s (+bias) (convolution @outChannels @filterSpatialShape @inChannels @s input filters)\n\n\n\nconv :: forall outChannels filterSpatialShape inChannels s t.\n               KnownShape s => KnownNat inChannels => KnownNat outChannels => KnownShape filterSpatialShape => KnownAlgebraic t\n            => Length filterSpatialShape <= 3\n            => (Length filterSpatialShape + 1) ~ Length s -- The ranks must match, but not necessarily the dimensions\n            => (Last s ~ outChannels)\n            => ConvP t outChannels inChannels filterSpatialShape\n            -> T (Init s ++ '[inChannels]) t\n            -> T s t\nconv = initLast' @s #>\n       incrPos @(Length filterSpatialShape) #>\n       lengthInit (typeSList @s) #>\n       incrCong @(Length filterSpatialShape) @(Length (Init s)) #>\n       knownInit @s ?>\n       conv' @(Init s)\n\n\n-- -- | Convolution layers with no padding (applying the filter only on\n-- -- positions where the input is fully defined, aka \"VALID\" in\n-- -- tensorflow.)\n-- convValid :: forall outChannels filterSpatialShape inChannels s t.\n--                   ((1 + Length filterSpatialShape) ~ Length s,\n--                    Length filterSpatialShape <= 3,\n--                    KnownLen filterSpatialShape) -- the last dim of s is the batch size\n--           => ConvP t outChannels inChannels filterSpatialShape -- ^ Parameters\n--           -> T ('[inChannels] ++ AddSpatialDims s filterSpatialShape) ('Typ 'Float t) -- ^ input\n--           -> (T ('[outChannels] ++ s) ('Typ 'Float t))\n-- convValid (ConvP filters bias) input = convolutionValid input filters + bias\n\n-- | Gated Linear Unit\n-- See: Language Modeling with Gated Convolutional 
Networks\n-- https://arxiv.org/pdf/1612.08083.pdf\nglu :: forall n t. KnownBits t => KnownNat n => T '[n+n] ('Typ 'Float t) -> T '[n] ('Typ 'Float t)\nglu x = plusMono @n @n #> knownPlus @n @n ?>\n        let gate, h :: T '[n] ('Typ 'Float t)\n            gate = slice0 @0 @n x\n            h =  termCancelation @n @n #> slice0 @n @(n+n) x\n        in sigmoid gate ⊙ h\n"
  },
  {
    "path": "TypedFlow/Layers/RNN/Attention.hs",
    "content": "{-|\nModule      : TypedFlow.Layers.RNN.Attention\nDescription : Attention combinators to be used with RNN cells\nCopyright   : (c) Jean-Philippe Bernardy, 2018\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n-}\n\n{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE MultiParamTypeClasses #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE ViewPatterns #-}\n{-# LANGUAGE TypeInType #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-# LANGUAGE PatternSynonyms #-}\n\nmodule TypedFlow.Layers.RNN.Attention (\n  -- * Attention mechanisms\n  -- ** Scoring functions\n  AttentionScoring,\n  multiplicativeScoring,\n  AdditiveScoringP(..), additiveScoring,\n  -- ** Attention functions\n  AttentionFunction,\n  uniformAttn,\n  luongAttention,\n  -- ** Attention combinators\n  attentiveWithFeedback\n  ) where\n\nimport Prelude hiding (RealFrac(..))\nimport GHC.TypeLits\nimport TypedFlow.TF\nimport TypedFlow.Types\nimport TypedFlow.Types.Proofs (appRUnit,(#>))\nimport TypedFlow.Layers.RNN.Base\n\n-- | An attention scoring function. This function should produce a\n-- score (between 0 and 1).\ntype AttentionScoring t keySize valueSize = \n  Tensor '[keySize] t -> Tensor '[valueSize] t -> Tensor '[] t\n\n-- | A function which attends to an external input. Typically a\n-- function of this type is a closure which has the attended input in\n-- its environment. 
This environment is interpreted as an associative\n-- memory from key to value.\ntype AttentionFunction t keySize valueSize =\n  T '[keySize] t -> T '[valueSize] t\n\n-- | @attnExample1 θ h st@ combines each element of the vector h with\n-- s, and applies a dense layer with parameters θ. The \"winning\"\n-- element of h (using softmax) is returned.\nuniformAttn :: ∀ valueSize m keySize t. KnownNat valueSize => KnownNat m => KnownFloat t\n       => AttentionScoring t keySize valueSize -- ^ scoring function\n       -> T '[] Int32 -- ^ length of the input\n       -> T '[m,valueSize] t -- ^ input (what we're attending to)\n       -> AttentionFunction t keySize valueSize\nuniformAttn score len hs key = c\n  where xx,α :: T '[m] t\n        xx = mapT (score key) hs\n        α = softmax0 (mask ⊙ xx)\n        c :: T '[valueSize] t\n        c = hs ∙ α\n        mask = cast (sequenceMask @m len) -- mask according to length\n\n-- | Add some attention to an RnnCell, and feed the attention vector to\n-- the next iteration in the rnn. (This follows the diagram at\n-- https://github.com/tensorflow/nmt#background-on-the-attention-mechanism\n-- commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a).\nattentiveWithFeedback ::forall attSize cellSize inputSize w ss. KnownNat inputSize => KnownNat attSize => KnownLen ss =>\n  KnownTyp w =>\n  AttentionFunction w cellSize attSize ->\n  RnnCell w ss                   (T '[inputSize+attSize] w) (T '[cellSize] w) ->\n  RnnCell w ('[attSize] ': ss)   (T '[inputSize        ] w) (T '[attSize] w)\nattentiveWithFeedback attn cell = appRUnit @ss #> withFeedback (cell .-. timeDistribute attn)\n\n\n-- -- | LSTM for an attention model. The result of attention is fed to the next step.\n-- attentiveLstm :: forall attSize n x bs t. 
KnownNat bs =>\n--   AttentionFunction t bs n attSize ->\n--   LSTMP t n (x+attSize) ->\n--   RnnCell t '[ '[attSize,bs], '[n,bs], '[n,bs] ] (Tensor '[x,bs] (Flt t)) (Tensor '[attSize,bs] (Flt t))\n-- attentiveLstm att w = attentiveWithFeedback att (lstm w)\n\n\n-- | Luong attention function (following\n-- https://github.com/tensorflow/nmt#background-on-the-attention-mechanism\n-- commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a).\n-- Essentially a dense layer with tanh activation, on top of uniform attention.\nluongAttention :: ∀ attnSize d m e w. KnownNat e => KnownNat d => KnownNat attnSize => KnownNat m => KnownFloat w\n  => Tensor '[d+e,attnSize] w     -- ^ weights for the dense layer\n  -> AttentionScoring w e d -- ^ scoring function\n  -> Tensor '[] Int32          -- ^ length of the input\n  -> T '[m,d] w         -- ^ inputs\n  -> AttentionFunction w e attnSize\nluongAttention w scoring lens hs_ ht = \n  let ct = uniformAttn scoring lens hs_ ht\n  in (tanh (w ∙ (concat0 ct ht)))\n\n-- | Multiplicative scoring function\nmultiplicativeScoring :: forall valueSize keySize t.\n  KnownFloat t => KnownNat valueSize => KnownNat keySize\n  => T [keySize,valueSize] t -- ^ weights\n  ->  AttentionScoring t keySize valueSize\nmultiplicativeScoring w dt h = ir · h\n  where ir :: T '[valueSize] t\n        ir = w ∙ dt\n\n\ndata AdditiveScoringP sz keySize valueSize t = AdditiveScoringP\n  (Tensor '[1,sz]          t)\n  (Tensor '[keySize, sz]   t)\n  (Tensor '[valueSize, sz] t)\n\ninstance (KnownNat n, KnownNat k, KnownNat v, KnownTyp t) => KnownTensors (AdditiveScoringP k v n t) where\n  travTensor f s (AdditiveScoringP x y z) = AdditiveScoringP <$> travTensor f (s<>\"_v\") x <*> travTensor f (s<>\"_w1\") y <*> travTensor f (s<>\"_w2\") z\ninstance (KnownNat n, KnownNat k, KnownNat v, KnownFloat t) => ParamWithDefault (AdditiveScoringP k v n t) where\n  defaultInitializer = AdditiveScoringP <$> glorotUniform <*> glorotUniform <*> glorotUniform\n\n-- | An additive scoring 
function. See https://arxiv.org/pdf/1412.7449.pdf\nadditiveScoring :: forall sz keySize valueSize t. KnownNat valueSize => KnownNat sz => KnownNat keySize => KnownFloat t =>\n  AdditiveScoringP sz keySize valueSize t -> AttentionScoring t valueSize keySize\nadditiveScoring (AdditiveScoringP v w1 w2) dt h =  r''\n  where w1h :: Tensor '[sz] t\n        w1h = w1 ∙ h\n        w2dt = w2 ∙ dt\n        z' :: Tensor '[sz] t\n        z' = tanh (w1h + w2dt)\n        r'' = z' · squeeze0 v\n\n"
  },
  {
    "path": "TypedFlow/Layers/RNN/Base.hs",
    "content": "{-|\nModule      : TypedFlow.Layers.RNN.Base\nDescription : RNN cells, layers and combinators.\nCopyright   : (c) Jean-Philippe Bernardy, 2017\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n-}\n\n{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE MultiParamTypeClasses #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE ViewPatterns #-}\n{-# LANGUAGE TypeInType #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-# LANGUAGE PatternSynonyms #-}\n\nmodule TypedFlow.Layers.RNN.Base (\n  -- * Cell Combinators\n  RnnCell,\n  simpleRnn,\n  runCell, mkCell,\n  stackRnnCells, (.-.),\n  bothRnnCells, (.|.),\n  withBypass, withFeedback,\n  onStates,\n  -- * Rnn Combinators\n  Rnn,\n  runRnn,\n  stackRnns, (.--.),\n  bothRnns,(.++.),\n  -- * RNN unfolding functions\n  timeDistribute,\n  iterateCell,\n  iterateCellBackward,\n  iterateWithCull,\n  -- * Monad-like interface for cell construction\n  Component(..), bindC, returnC,\n  -- rnnBackwardsWithCull,\n  )\n\nwhere\nimport Prelude hiding (tanh,Num(..),Floating(..),floor)\nimport GHC.TypeLits\nimport TypedFlow.TF\nimport TypedFlow.Types\nimport TypedFlow.Types.Proofs\n-- import Data.Type.Equality\n-- import Data.Kind (Type,Constraint)\n\n-- | The RNN Component generalized monad. This can be used to build\n-- RNNs cells which do not follow the simple and usual \"stacking\"\n-- pattern. 
This is not a simple monad, because the indexing over\n-- states is non-uniform; see 'BindC'.\nnewtype Component t (states::[Shape]) a\n  = C {runC :: HTV t states -> (HTV t states , a)}\n-- Note: states are tensors only, because we need to index into them\n-- in the time dimension in iterateWithCull\n\ninstance Functor (Component t states) where\n  fmap = mapC\n\nmapC :: (a -> b) -> Component t s a -> Component t s b\nmapC f c = C $ \\s ->\n  let (s',x) = runC c s\n  in (s', f x)\n\n-- | Unit of the Component monad.\nreturnC :: a -> Component t '[] a\nreturnC x = C $ \\Unit -> (Unit,x)\n\n-- | Bind operation for Components. States are accumulated.\nbindC :: forall t s0 s1 a b. KnownLen s1\n  => Component t s0 a -> (a -> Component t s1 b) -> Component t (s1++s0) b\nbindC f g = C $ \\(hsplit @s1 -> (s1,s0)) -> \n  let (s0',x) = runC f s0\n      (s1',y) = runC (g x) s1\n  in (happ s1' s0',y)\n\n-- | A cell (one time-step) in an rnn. @state@ is the state propagated through time.\ntype RnnCell t states input output = input -> Component t states output\n\n-- | An rnn. @n@ is the length of the time sequence. @state@ is the state propagated through time.\ntype Rnn n b state input output = RnnCell b state (V n input) (V n output) \n\n-- | Run a cell\nrunCell :: RnnCell t states input output -> (HTV t states,input) -> (HTV t states, output)\nrunCell cell = uncurry (flip (runC . cell))\n\n-- | Run an RNN, using a tensor as input. @n@ is the length of the time sequence. 
\nrunRnn :: (KnownNat n,KnownShape s0, KnownShape s1, KnownTyp t1)\n       => Rnn n t2 states (T s1 t1) (T s0 t0)\n       -> (HTV t2 states, Tensor (n ': s1) t1)\n       -> (HTV t2 states, Tensor (n ': s0) t0)\nrunRnn l (s,x) =\n  let x' = unstack0 x\n      (s',y) = runCell l (s,x')\n  in (s',stack0 y)\n\n-- | Run an RNN composed of a single RNN cell.\nsimpleRnn :: KnownTyp t1 => KnownShape s1 => KnownShape s0 => KnownNat n\n          => RnnCell t2 states (T s1 t1) (T s0 t0)\n          -> (HTV t2 states, Tensor (n : s1) t1)\n          -> (HTV t2 states, Tensor (n : s0) t0)\nsimpleRnn = runRnn . iterateCell\n\n-- | Construct a cell from an arbitrary stateful function\nmkCell :: ((HTV t states,input) -> (HTV t states, output)) -> RnnCell t states input output\nmkCell cell = C . flip (curry cell)\n\n----------------------\n-- Lifting functions\n\n-- | Convert a pure function (feed-forward layer) to an RNN cell by\n-- ignoring the RNN state.\ntimeDistribute :: (a -> b) -> RnnCell t '[] a b\ntimeDistribute = constantOverSteps\n\n-- | Convert a pure function (feed-forward layer) to an RNN cell by\n-- ignoring the RNN state.\nconstantOverSteps :: (a -> b) -> RnnCell t '[] a b\nconstantOverSteps stateLess a = returnC (stateLess a)\n\n--------------------------------------\n-- Combinators\n\n-- | Compose two rnn layers. This is useful for example to combine\n-- forward and backward layers.\n(.--.),stackRnns :: forall s1 s2 a b c n bits. KnownLen s2\n  => Rnn n bits s1 a b -> Rnn n bits s2 b c -> Rnn n bits (s2 ++ s1) a c\nstackRnns = stackRnnCells\n\ninfixr .--.\n(.--.) = stackRnns\n\n-- | Compose two rnn layers in parallel.\nbothRnns,(.++.)  
:: forall s1 s2 a b c n bits t.\n  KnownTyp t => KnownLen s1 => KnownLen s2 => KnownNat n\n  => KnownNat b => KnownNat c\n  => Rnn n bits s1 a (T '[b] t) -> Rnn n bits s2 a (T '[c] t) -> Rnn n bits (s2 ++ s1) a (T ('[b+c]) t)\nbothRnns f g x =\n  f x `bindC` \\y ->\n  g x `bindC` \\z ->\n  returnC (concat0 <$> y <*> z)\n\ninfixr .++.\n(.++.) = bothRnns\n\n-- | Apply a function on the cell state(s) before running the cell itself.\nonStates ::  (HTV t xs -> HTV t xs) -> RnnCell t xs a b -> RnnCell t xs a b\nonStates f cell x = C $ \\h -> do\n  runC (cell x) (f h)\n\n-- | Stack two RNN cells (LHS is run first)\nstackRnnCells, (.-.) :: forall s0 s1 a b c t. KnownLen s1\n  => RnnCell t s0 a b -> RnnCell t s1 b c -> RnnCell t (s1 ++ s0) a c\nstackRnnCells l1 l2 x = l1 x `bindC` l2\n(.-.) = stackRnnCells\n\n\n-- | Compose two rnn cells in parallel.\nbothRnnCells, (.|.) :: forall s0 s1 a b c t bits. KnownLen s0 => KnownLen s1\n  => KnownBits bits\n  => KnownNat b => KnownNat c\n  => RnnCell t s0 a (T '[b] (Flt bits))\n  -> RnnCell t s1 a (T '[c] (Flt bits))\n  -> RnnCell t (s1 ++ s0) a (T '[b+c] (Flt bits))\nbothRnnCells l1 l2 x  =\n  l1 x `bindC` \\y ->\n  l2 x `bindC` \\z ->\n  returnC (concat0 y z)\n\n(.|.) = bothRnnCells\n\n\n-- | Run the cell, and forward the input to the output, by\n-- concatenation with the output of the cell. This bypass is sometimes\n-- called a 'highway' in the literature.\nwithBypass :: forall x y t b s0. 
KnownNat x => KnownNat y => KnownLen s0\n  => KnownTyp t\n  => RnnCell b s0 (T '[x] t) (T '[y] t) -> RnnCell b s0 (T '[x] t) (T '[x+y] t)\nwithBypass cell x = appRUnit @s0 #>\n  cell x `bindC` \\y ->\n  returnC (concat0 x y)\n\n-- | Run the cell, and feeds its output as input to the next time-step\nwithFeedback :: forall outputSize inputSize (w :: Typ) ss.\n  KnownTyp w => KnownNat outputSize => KnownNat inputSize =>\n  RnnCell w ss                    (T '[inputSize+outputSize] w) (T '[outputSize] w) ->\n  RnnCell w ('[outputSize] ': ss) (T '[inputSize           ] w) (T '[outputSize] w)\nwithFeedback cell x = C $ \\(F prevoutputnVector :* s) -> \n  let (s',y) = runC (cell (concat0 x prevoutputnVector)) s\n  in  (F y :* s',y)\n\n---------------------------------------------------------\n-- RNN unfolding\n\n-- | Build a RNN by repeating a cell @n@ times.\niterateCell :: ∀ n state input output b.\n       (KnownNat n) =>\n       RnnCell b state input output -> Rnn n b state input output\niterateCell c x = C $ \\s -> chainForward (\\(t,y) -> runC (c y) t) (s,x)\n\n-- | Build a RNN by repeating a cell @n@ times. However the state is\n-- propagated in the right-to-left direction (decreasing indices in\n-- the time dimension of the input and output tensors)\niterateCellBackward :: ∀ n state input output b.\n       (KnownNat n) =>\n       RnnCell b state input output -> Rnn n b state input output\niterateCellBackward c x = C $ \\s -> chainBackward (\\(t,y) -> runC (c y) t) (s,x)\n\n-- | RNN helper\nchainForward :: ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (state , V n b)\nchainForward _ (s0 , VUnit) = (s0 , VUnit)\nchainForward f (s0 , x :** xs) = \n  let (s1,x') = f (s0 , x)\n      (sFin,xs') = chainForward f (s1 , xs)\n  in  (sFin,(x':**xs'))\n\n-- | RNN helper\nchainBackward :: ∀ state a b n. 
((state , a) -> (state , b)) → (state , V n a) -> (state , V n b)\nchainBackward _ (s0 , VUnit) = (s0 , VUnit)\nchainBackward f (s0 , (x:**xs)) =\n  let (s1,xs') = chainBackward f (s0,xs)\n      (sFin, x') = f (s1,x)\n  in (sFin,(x':**xs'))\n\n\n-- | RNN helper\nchainForwardWithState :: ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (V n b, V n state)\nchainForwardWithState _ (_s0 , VUnit) = (VUnit, VUnit)\nchainForwardWithState f (s0 , (x:**xs)) =\n  let (s1,x') = f (s0 , x)\n      (xs',ss) = chainForwardWithState f (s1 , xs)\n  in ((x':**xs'), (s1:**ss) )\n\n-- -- | RNN helper\n-- chainBackwardWithState ::\n--   ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (state , V n b, V n state)\n-- chainBackwardWithState _ (s0 , VUnit) = return (s0 , VUnit, VUnit)\n-- chainBackwardWithState f (s0 , (x:**xs)) = do\n--   (s1,xs',ss') <- chainBackwardWithState f (s0,xs)\n--   (sFin, x') <- f (s1,x)\n--   return (sFin,(x':**xs'),(sFin:**ss'))\n\n-- | RNN helper\ntransposeV :: forall n xs t. All KnownShape xs => KnownNat n =>\n               SList xs -> V n (HTV t xs) -> HTV t (Ap (FMap (Cons n)) xs)\ntransposeV Unit _ = Unit\ntransposeV (_ :* n) xxs  = F ys' :* yys'\n  where (ys,yys) = help @(Tail xs) xxs\n        ys' = stack0 ys\n        yys' = transposeV n yys\n        help :: forall ys x tt. V n (HTV tt (x ': ys)) -> (V n (T x tt) , V n (HTV tt ys))\n        help (xs) = ((fmap (fromF . hhead) xs),(fmap htail xs))\n\n-- | @(gatherFinalStates dynLen states)[i] = states[dynLen[i]-1]@\ngatherFinalStates :: KnownShape x => KnownNat n => T '[] Int32 -> T (n ': x) t -> T x t\ngatherFinalStates dynLen states = gather states (dynLen ⊝ constant 1)\n\ngathers :: forall n xs t. 
All KnownShape xs => KnownNat n =>\n            SList xs -> T '[] Int32 -> HTV t (Ap (FMap (Cons n)) xs) -> HTV t xs\ngathers Unit _ Unit = Unit\ngathers (_ :* n) ixs (F x :* xs) = F (gatherFinalStates ixs x) :* gathers @n n ixs xs\n\n-- | @iterateWithCull dynLen@ constructs an RNN as normal, but returns the\n-- state after step @dynLen@ only.\niterateWithCull :: forall n x y ls b.\n  KnownLen ls => KnownNat n => All KnownShape ls =>\n  T '[] Int32 -- ^ dynamic length\n  -> RnnCell b ls x y -> Rnn n b ls x y\niterateWithCull dynLen cell xs = C $ \\s0 ->\n  let (us,ss) = chainForwardWithState (uncurry (flip (runC . cell))) (s0,xs)\n      sss = transposeV @n (typeSList @ls) ss\n  in (gathers @n (typeSList @ls) dynLen sss,us)\n\n-- -- | Like @rnnWithCull@, but states are threaded backwards.\n-- rnnBackwardsWithCull :: forall n bs x y ls b.\n--   KnownLen ls => KnownNat n => All KnownLen ls => All (LastEqual bs) ls =>\n--   T '[bs] Int32 -> RnnCell b ls x y -> RNN n b ls x y\n-- rnnBackwardsWithCull dynLen cell (s0, t) = do\n--   (us,ss) <- chainBackwardWithState cell (s0,xs)\n--   let sss = transposeV @n (shapeSList @ls) ss\n--   return (gathers @n (shapeSList @ls) (n - dynLen) sss,us)\n"
  },
  {
    "path": "TypedFlow/Layers/RNN/Cells.hs",
    "content": "{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE ViewPatterns #-}\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-|\nModule      : TypedFlow.Layers.RNN.Cells\nDescription : RNN cells\nCopyright   : (c) Jean-Philippe Bernardy, 2017\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n-}\n\n\nmodule TypedFlow.Layers.RNN.Cells (\n  -- * RNN Cells\n  cellInitializerBit,\n  LSTMP(..),\n  lstm,\n  GRUP(..),\n  gru,\n  StackP(..),\n  stackRU,\n  ) where\n\nimport TypedFlow.Layers.RNN.Base\nimport TypedFlow.TF\nimport TypedFlow.Types\nimport TypedFlow.Types.Proofs\nimport GHC.TypeLits\nimport TypedFlow.Layers.Core (DenseP(..),(#))\nimport Prelude hiding (RealFrac(..))\n\n--------------------------------------\n-- Cells\n\n-- | Standard RNN gate initializer. (The recurrent kernel is\n-- orthogonal to avoid divergence; the input kernel is glorot)\ncellInitializerBit :: ∀ n x t. 
(KnownNat n, KnownNat x, KnownFloat t) => Gen (DenseP t (n + x) n)\ncellInitializerBit = DenseP <$> (concat0 <$> recurrentInitializer <*> kernelInitializer) <*> biasInitializer\n  where recurrentInitializer :: Gen (Tensor '[n, n] t)\n        recurrentInitializer = noise $ OrthogonalD\n        kernelInitializer :: Gen (Tensor '[x, n] t)\n        kernelInitializer = glorotUniform\n        biasInitializer = pure zeros\n\n-- | Parameter for an LSTM\ndata LSTMP t n x = LSTMP (DenseP t (n+x) n) (DenseP t (n+x) n) (DenseP t (n+x) n) (DenseP t (n+x) n)\n\ninstance (KnownNat n, KnownNat x, KnownFloat t) => KnownTensors (LSTMP t n x) where\n  travTensor f s (LSTMP x y z w) = LSTMP <$> travTensor f (s<>\"_f\") x <*> travTensor f (s<>\"_i\") y <*> travTensor f (s<>\"_c\") z <*> travTensor f (s<>\"_o\") w\ninstance (KnownNat n, KnownNat x, KnownFloat t) => ParamWithDefault (LSTMP t n x) where\n  defaultInitializer = LSTMP <$> forgetInit <*> cellInitializerBit <*> cellInitializerBit <*> cellInitializerBit\n    where forgetInit = DenseP <$> (denseWeights <$> cellInitializerBit) <*> pure ones\n\n-- | Standard LSTM\nlstm :: ∀ n x t. 
KnownNat x => KnownNat n => KnownFloat t\n  => LSTMP t n x -> RnnCell t '[ '[n], '[n]] (Tensor '[x] t) (Tensor '[n] t)\nlstm (LSTMP wf wi wc wo) input = C $ \\(VecPair ht1 ct1) -> \n  let f = sigmoid (wf # hx)\n      hx = (concat0 ht1 input)\n      i = sigmoid (wi # hx)\n      cTilda = tanh (wc # hx)\n      o = sigmoid (wo # hx)\n      c = ((f ⊙ ct1) + (i ⊙ cTilda))\n      h = (o ⊙ tanh c)\n  in (VecPair h c, h)\n\n-- | Parameter for a GRU\ndata GRUP t n x = GRUP (T [n+x,n] t) (T [n+x,n] t) (T [n+x,n] t)\n\ninstance (KnownNat n, KnownNat x, KnownFloat t) => KnownTensors (GRUP t n x) where\n  travTensor f s (GRUP x y z) = GRUP <$> travTensor f (s<>\"_z\") x <*> travTensor f (s<>\"_r\") y <*> travTensor f (s<>\"_w\") z\ninstance (KnownNat n, KnownNat x, KnownFloat t) => ParamWithDefault (GRUP t n x) where\n  defaultInitializer = GRUP <$> (denseWeights <$> cellInitializerBit) <*> (denseWeights <$> cellInitializerBit) <*> (denseWeights <$> cellInitializerBit)\n\n\n\n-- | Standard GRU cell\ngru :: ∀ n x t. KnownNat x => (KnownNat n, KnownFloat t) => GRUP t n x ->\n        RnnCell t '[ '[n] ] (Tensor '[x] t) (Tensor '[n] t)\ngru (GRUP wz wr w) xt = C $ \\(VecSing ht1) ->\n  let hx =  (concat0 ht1 xt)\n      zt = sigmoid (wz ∙ hx)\n      rt = sigmoid (wr ∙ hx)\n      hTilda = tanh (w ∙ (concat0 (rt ⊙ ht1) xt))\n      ht = ((ones ⊝ zt) ⊙ ht1 + zt ⊙ hTilda)\n  in (VecSing ht, ht)\n\n\ndata StackP w n = StackP (DenseP w (n + n) 3)\n\ndefStackP :: KnownNat n => KnownFloat w => Gen (StackP w n)\ndefStackP = StackP <$> defaultInitializer\n  -- (DenseP glorotUniform (stack0 (V [zeros, constant (-1), zeros]) )) -- demote popping a bit \n\ninstance (KnownNat n, KnownTyp w) => KnownTensors (StackP w n) where\n  travTensor f s (StackP d) = StackP <$> travTensor f s d\n\ninstance (KnownNat n, KnownFloat w) => (ParamWithDefault (StackP w n)) where\n  defaultInitializer = defStackP\n\n-- | A stack recurrent unit. The input has two purposes: 1. it is\n-- saved in a stack. 2. 
it controls (a dense layer which gives) the\n-- operation to apply on the stack.  The first type argument is the\n-- depth of the stack.\nstackRU :: ∀k n bs w. KnownNat k => KnownNat n => (KnownNat bs) => (KnownFloat w) => StackP w n ->\n        RnnCell w '[ '[k+1,n]] (Tensor '[n] w) (Tensor '[n] w)\nstackRU (StackP w) input = C $ \\(VecSing st1) ->\n  succPos @k #>\n  plusMono @k @1 #>\n  plusComm @k @1 #>\n  termCancelation @k @1 #>\n  let ct1 = nth0' @0 st1\n      hx = concat0 ct1 input\n      action :: T '[3] w\n      action = softmax0 (w # hx)\n      tl :: T '[k,n] w\n      tl = slice0 @1 @(k+1) st1\n      it :: T '[k,n] w\n      it = slice0 @0 @k  st1\n      stTilda :: T '[3,k+1,n] w\n      stTilda = stack0 (st1 :**  (tl `concat0` zeros) :** (expandDim0 input `concat0` it) :** VUnit)\n      st :: T '[k+1,n] w\n      st = inflate2 (flatten12 stTilda ∙ action)\n      ct = nth0' @0 st\n  in (VecSing st, ct)\n\n"
  },
  {
    "path": "TypedFlow/Layers/RNN.hs",
    "content": "{-|\nModule      : TypedFlow.Layers.RNN\nDescription : RNN cells, layers and combinators.\nCopyright   : (c) Jean-Philippe Bernardy, 2017\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n-}\n\n\nmodule TypedFlow.Layers.RNN (\n  module TypedFlow.Layers.RNN.Base,\n  module TypedFlow.Layers.RNN.Cells,\n  module TypedFlow.Layers.RNN.Attention)  where\n\nimport TypedFlow.Layers.RNN.Base\nimport TypedFlow.Layers.RNN.Cells\nimport TypedFlow.Layers.RNN.Attention\n"
  },
  {
    "path": "TypedFlow/Layers.hs",
    "content": "\nmodule TypedFlow.Layers\n  (module  TypedFlow.Layers.Core\n  ,module  TypedFlow.Layers.RNN\n  ) where\n\nimport TypedFlow.Layers.Core\nimport TypedFlow.Layers.RNN\n\n"
  },
  {
    "path": "TypedFlow/Learn.hs",
    "content": "{-# LANGUAGE LambdaCase #-}\n{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE PatternSynonyms #-}\n{-|\nModule      : TypedFlow.Learn\nDescription : Loss functions and optimization strategies\nCopyright   : (c) Jean-Philippe Bernardy, 2017\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE ApplicativeDo #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE GeneralizedNewtypeDeriving #-}\n{-# LANGUAGE InstanceSigs #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE RecordWildCards #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TupleSections #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeInType #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UndecidableSuperClasses #-}\n{-# LANGUAGE UnicodeSyntax #-}\n\nmodule TypedFlow.Learn\n  (-- losses:\n    sparseCategorical, binary, timedCategorical, categoricalDistribution,sparseCategoricalDensePredictions,\n    -- types\n    Options(..), defaultOptions,\n    Function(..),Model,ModelOutput,\n    PreparedFunction(..), PreparedModel(..),\n    -- other\n    simpleModel, modelFunction, probeFunction,\n    addRegularizer,\n    prepare,\n    -- utils\n    placeholderName,\n  ) where\n\nimport Data.Proxy\nimport TypedFlow.Types\nimport TypedFlow.Types.Proofs (knownAppend,  (?>), )\nimport TypedFlow.Broadcast (doBroadcast, mapPlaceHolders, ConsSh,doBroadcastSingle)\nimport TypedFlow.Abstract (doExtractVars)\nimport TypedFlow.TF\nimport Prelude hiding (RealFrac(..))\nimport GHC.TypeLits\n\n-- | Triple of values that are always output 
in a model: prediction, loss and accuracy.\n-- @t@ is the type of the prediction.\n-- @s@ is the shape of the loss and accuracy\ntype ModelOutput t predictionShape s\n  = Placeholders '[ '(\"loss\",s,Float32) -- loss associated with the prediction\n                  , '(\"accuracy\",s,Float32)  -- is the prediction correct?\n                  , '(\"y_\",s++predictionShape,t) -- prediction (which can contain prediction-shaped info)\n                  ]\n\npattern ModelOutput ::  T (s++predictionShape) t -> T s Float32 -> T s Float32 -> ModelOutput t predictionShape s\npattern ModelOutput y loss accur = PHT loss :* PHT accur :* PHT y :* Unit\n\n-- | A standard modelling function: (input value, gold value) ↦ (prediction, accuracy, loss).\n-- input is the shape of the input.\n-- output is the shape of the output (one element per individual loss and accuracy)\n-- p is the shape of each output element.\n-- g is the shape of each gold output --- often equal to p.\ntype Model input tIn g p output tOut\n  = T input tIn -> T (g++output) tOut -> ModelOutput tOut p output\n\n-- | First type argument is the number of classes.  @categorical\n-- logits gold@ return (prediction, accuracy, loss)\n\nsparseCategorical :: forall nCat. KnownNat nCat => Model '[nCat] Float32 '[] '[] '[] Int32\nsparseCategorical logits y =\n  let y_ = argmax0 logits\n      modelCorrect = cast (equal y_ y)\n      modelLoss = sparseSoftmaxCrossEntropyWithLogits y logits\n  in ModelOutput y_ modelLoss modelCorrect\n\n-- | First type argument is the number of classes.  @categorical\n-- logits gold@ return (prediction, accuracy, loss)\nsparseCategoricalDensePredictions :: forall nCat. 
KnownNat nCat\n  => Tensor '[nCat] Float32\n  -> Tensor '[] Int32\n  -> ModelOutput  Float32 '[nCat] '[]\nsparseCategoricalDensePredictions logits y =\n  let y_ :: T '[nCat] Float32\n      y_ = softmax0 logits\n      modelCorrect = cast (equal (argmax0 logits) y)\n      modelLoss = sparseSoftmaxCrossEntropyWithLogits y logits\n  in ModelOutput y_ modelLoss modelCorrect\n\n\n-- | First type argument is the number of classes.\n-- @categoricalDistribution logits gold@ return (prediction,\n-- accuracy, loss). Accuracy is reported as predicting the same class\n-- as the input 'winning' class.\ncategoricalDistribution :: forall nCat. KnownNat nCat => Model '[nCat] Float32 '[nCat] '[nCat] '[] Float32\ncategoricalDistribution logits y =\n  ModelOutput (softmax0 logits)\n              (softmaxCrossEntropyWithLogits y logits)\n              (cast (equal (argmax0 @'B32 logits) (argmax0 y)))\n  \n\n-- | @timedCategorical targetWeights logits y@\n--\n-- targetWeights: a zero-one matrix of the same size as\n-- decoder_outputs. It is intended to mask padding positions outside\n-- of the target sequence lengths with values 0.\n--\n-- Note that the accuracy is computed by multiplying the accuracies at\n-- individual time steps with the targetWeights.\n\ntimedCategorical :: forall len nCat bits. 
KnownNat nCat => KnownNat len => KnownBits bits =>\n  Tensor '[len] (Flt bits) -> Tensor '[len,nCat] (Flt bits) -> Tensor '[len] Int32 -> ModelOutput  (Flt bits) '[len,nCat] '[]\ntimedCategorical targetWeights logits y =\n  let y_ :: Tensor '[len] Int32\n      y_ = argmax1 logits\n      modelY = softmax1 logits\n      -- correct prediction for each position\n      correctPrediction :: Tensor '[len] TFBool\n      correctPrediction = equal y_ y\n      -- total number of correct predictions\n      correctPredictionWeighted :: Tensor '[] (Flt bits)\n      correctPredictionWeighted = reduceSumAll (cast @(Flt bits) correctPrediction ⊙ targetWeights)\n      weightSum = reduceSumAll targetWeights\n      modelCorrect :: Tensor '[] Float32\n      modelCorrect = cast (correctPredictionWeighted / weightSum)\n      crossEntropies = zipWithT sparseSoftmaxCrossEntropyWithLogits y logits\n      modelLoss = cast @Float32 (reduceSumAll (crossEntropies ⊙ targetWeights) / weightSum)\n  in ModelOutput modelY modelLoss modelCorrect\n\n-- | Model with @n@ binary outputs.\nbinary :: KnownNat n => Model '[n] Float32 '[] '[] '[n] Int32\nbinary logits y =\n  let y_ = cast @Int32 (round sigy_)\n      sigy_ = sigmoid logits\n  in ModelOutput (y_)\n                 (sigmoidCrossEntropyWithLogits (cast @Float32 y) logits)\n                 (cast (equal y_ y))\n\n-- | Model compiler options\ndata Options = Options {maxGradientNorm :: Maybe Prelude.Float -- ^ apply gradient clipping\n                       }\n\n-- | default model compiler options\ndefaultOptions :: Options\ndefaultOptions = Options {maxGradientNorm = Nothing}\n\ntype family Concatenate xs where\n  Concatenate (x ': xs) = x ++ Concatenate xs\n  Concatenate '[] = '[]\n\ngenPlaceholders :: All KnownPlaceholder shapesAndTypes => SList shapesAndTypes -> Placeholders shapesAndTypes\ngenPlaceholders Unit = Unit\ngenPlaceholders (ph :* names) = PHT (T (ExternalVar (Ref (placeholderName ph) typeSShape typeSTyp))) :* genPlaceholders 
names\n\nplaceholderName :: forall (ph :: PH)  p. KnownPlaceholder ph => p ph -> String\nplaceholderName proxy = refName (placeHolderRef proxy)\n\nsimpleModel :: forall p sx tx sy ty sy_ ty_.\n               (KnownShape sy_, KnownShape p, KnownShape sx, KnownTyp ty_, KnownShape sy, KnownTyp tx, KnownTyp ty)\n            => (Tensor sx tx -> Tensor sy ty -> ModelOutput  ty_ p sy_)\n            -> Function \nsimpleModel f = knownAppend @sy_ @p ?> modelFunction \"runModel\" f'\n  where f' :: Placeholders '[ '(\"x\",sx,tx), '(\"y\",sy,ty)] -> ModelOutput ty_ p sy_\n        f' (PHT x :* PHT y :* Unit) = f x y\n\n\n-- | Add a term to the loss. This function is intended to add\n-- regularizers, i.e. losses that do not depend on the predicted\n-- output, but rather on the structure of a parameter.\naddRegularizer :: Scalar Float32 -> Gen ()\naddRegularizer r = GPState  $ \\GState{..} -> ((),GState{genRegularizers=r:genRegularizers,..})\n\n\n       \nknownBatchModel :: forall n ps. KnownNat n => NP (Sat KnownPlaceholder) ps -> NP (Sat KnownPlaceholder) (Ap (FMap (ConsSh n)) ps)\nknownBatchModel Unit = Unit\nknownBatchModel (Comp Dict :* xs) = Sat :* knownBatchModel @n xs\n\n-- | take the mean of loss/accur over the batch, etc. and add regulariser to loss\nconsolidate :: forall s rest. 
KnownShape s\n            => Scalar Float32\n            -> Placeholders ( '(\"loss\",s  ,Float32) ': '(\"accuracy\",s  ,Float32) ': rest)\n            -> Placeholders ( '(\"loss\",'[],Float32) ': '(\"accuracy\",'[],Float32) ': rest)\nconsolidate extraLoss (PHT loss :* PHT accur :* rest) = (PHT (reduceMeanAll loss + extraLoss) :* PHT (reduceMeanAll accur) :* rest)\n\nclass (All KnownPlaceholder ps, KnownLen ps) => KnownPHS ps\ninstance (All KnownPlaceholder ps, KnownLen ps) => KnownPHS ps\n\ndata PreparedFunction = PreparedFunction {pfName :: String,\n                                          pfBatched :: Bool,\n                                          pfInputs, pfOutputs :: SomeSuch KnownPHS Placeholders}\ndata PreparedModel = PreparedModel {pmBatchSize :: Integer,\n                                    pmParams :: [VarInfo],\n                                    pmFunctions :: [PreparedFunction]\n                                   }\n\n-- | Prepare compilation of a model by:\n-- extracting and exposing parameters \n-- batching the model\n-- exposing placeholders\n-- consolidating loss and accuracy\n-- adding regularizers to the loss\nprepare :: forall bs. 
(KnownNat bs)\n        => Gen [Function]\n        -> PreparedModel\nprepare fGen =\n  PreparedModel\n    {pmBatchSize = natVal (Proxy @bs)\n    ,pmParams = [VarInfo{varInitial=fmap doBroadcastSingle varInitial,..} | VarInfo{..} <- filter varTrainable vars]\n    ,pmFunctions = flip map fs $ \\case\n        ModelFn nm st1 st2 f ->\n          knownAll (knownBatchModel @bs st1) $\n          knownAll (knownBatchModel @bs st2) $\n          knownAll st1 $ \n          knownAll st2 $ \n          let placeHolders = genPlaceholders typeSList\n              u = -777 -- magic unique identifier for the batch dimension\n          in PreparedFunction nm\n               True\n               (SomeSuch placeHolders)\n               (SomeSuch $ doBroadcast (consolidate {-@(bs ': s) @(BPH bs st2)-} regular (mapPlaceHolders @bs u True f placeHolders)))\n        ProbeFn nm st1 st2 f -> \n          knownAll st1 $\n          knownAll st2 $\n          let placeHolders = genPlaceholders typeSList\n          in PreparedFunction nm False (SomeSuch placeHolders) (SomeSuch (doBroadcast (f placeHolders)))\n    }\n  where (fs,finalState,vars) = doExtractVars fGen\n        regular = sum (genRegularizers finalState)\n\ndata Function where\n  ModelFn :: (KnownShape s, KnownLen st1, KnownLen st2)\n          => String\n          -> NP (Sat KnownPlaceholder) st1 -> NP (Sat KnownPlaceholder) st2 \n          -> (Placeholders st1 -> Placeholders ('(\"loss\",s,Float32) ': '(\"accuracy\",s,Float32) ': st2)) -> Function\n  ProbeFn :: (KnownLen st1, KnownLen st2, All KnownPlaceholder st1, All KnownPlaceholder st2)\n          => String\n          -> NP (Sat KnownPlaceholder) st1 -> NP (Sat KnownPlaceholder) st2 \n          -> (Placeholders st1 -> Placeholders st2) -> Function\n\nmodelFunction :: (KnownShape s, KnownLen st1, KnownLen st2, All KnownPlaceholder st1, All KnownPlaceholder st2)\n          => String\n          -> (Placeholders st1 -> Placeholders ('(\"loss\",s,Float32) ': '(\"accuracy\",s,Float32) ': 
st2)) -> Function\nmodelFunction nm f = ModelFn nm (allKnown @KnownPlaceholder) (allKnown @KnownPlaceholder) f\n\n\nprobeFunction :: (KnownLen st1, KnownLen st2, All KnownPlaceholder st1, All KnownPlaceholder st2)\n          => String\n          -> (Placeholders st1 -> Placeholders st2) -> Function\nprobeFunction nm f = ProbeFn nm (allKnown @KnownPlaceholder) (allKnown @KnownPlaceholder) f\n\n\n"
  },
  {
    "path": "TypedFlow/Memo.hs",
    "content": "{-# LANGUAGE TypeInType #-}\n{-# LANGUAGE PolyKinds #-}\n{-# LANGUAGE KindSignatures #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE GADTs #-}\nmodule TypedFlow.Memo where\n\nimport qualified Data.IntMap as I\nimport qualified Data.Map.Strict as M\nimport System.Mem.StableName\nimport Data.IORef\nimport System.IO.Unsafe\nimport Unsafe.Coerce\nimport Data.Kind (Type)\ntype SNMap k v = I.IntMap [(StableName k,v)]\n\nsnMapLookup :: StableName k -> SNMap k v -> Maybe v\nsnMapLookup sn m = do\n  x <- I.lookup (hashStableName sn) m\n  lookup sn x\n\nsnMapInsert :: StableName k -> v -> SNMap k v -> SNMap k v\nsnMapInsert sn res = I.insertWith (++) (hashStableName sn) [(sn,res)]\n\nmemo :: (a -> b) -> a -> b\nmemo f = unsafePerformIO (\n  do { tref <- newIORef (I.empty)\n     ; return (applyStable f tref)\n     })\n\napplyStable :: (a -> b) -> IORef (SNMap a b) -> a -> b\napplyStable f tbl arg = unsafePerformIO (\n  do { sn <- makeStableName arg\n     ; lkp <- snMapLookup sn <$> readIORef tbl\n     ; case lkp of\n         Just result -> return result\n         Nothing ->\n           do { let res = f arg\n              ; modifyIORef tbl (snMapInsert sn res)\n              ; return res\n              }})\n\nmemoOrd :: Ord a => (a -> b) -> a -> b\nmemoOrd f = unsafePerformIO (\n  do { tref <- newIORef (M.empty)\n     ; return (applyStableOrd f tref)\n     })\n\napplyStableOrd :: Ord a => (a -> b) -> IORef (M.Map a b) -> a -> b\napplyStableOrd f tbl arg = unsafePerformIO (\n  do { lkp <- M.lookup arg <$> readIORef tbl\n     ; case lkp of\n         Just result -> return result\n         Nothing ->\n           do { let res = f arg\n              ; modifyIORef tbl (M.insert arg res)\n              ; return res\n              }})\n\n\ndata Some2 k1 k2 (f :: k1 -> k2 -> Type) where\n  Some2 :: forall k1 k2 f a b. 
StableName (f a b) -> Some2 k1 k2 f\n\ninstance Eq (Some2 k1 k2 f) where\n  Some2 sn1 == Some2 sn2 = eqStableName sn1 sn2\n\ntype SSNMap2 k1 k2 (f :: k1 -> k2 -> Type) v = I.IntMap [(Some2 k1 k2 f,v)]\n\nmakeSn2 :: f a b -> Some2 k1 k2 f\nmakeSn2 = Some2 . unsafePerformIO . makeStableName\n\nsnMapLookup2 :: Some2 k1 k2 f -> SSNMap2 k1 k2 f v -> Maybe v\nsnMapLookup2 (Some2 sn) m = do\n  x <- I.lookup (hashStableName sn) m\n  lookup (Some2 sn) x\n\nsnMapInsert2 :: Some2 k1 k2 f -> v -> SSNMap2 k1 k2 f v -> SSNMap2 k1 k2 f v\nsnMapInsert2 (Some2 sn) res = I.insertWith (++) (hashStableName sn) [(Some2 sn,res)]\n\ndata KV k1 k2 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type)  where\n  KV :: forall k1 k2 f v a b. StableName (f a b) -> v a b -> KV k1 k2 f v\n\ntype SNMap22 k1 k2 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) = I.IntMap [KV k1 k2 f v]\n\nsnMap22Lookup :: StableName (f a b) -> SNMap22 k1 k2 f v -> Maybe (v a b)\nsnMap22Lookup sn  m = do\n  x <- I.lookup (hashStableName sn) m\n  lkKV sn x\n\nlkKV :: StableName (f a b) -> [KV k1 k2 f v] -> Maybe (v a b)\nlkKV _ [] = Nothing\nlkKV sn (KV sn' v:kvs) | eqStableName sn sn' = Just (unsafeCoerce v) -- sn == sn' -> a == a' and b == b' \n                       | otherwise = lkKV sn kvs\n\nsnMap22Insert :: KV k1 k2 f v -> SNMap22 k1 k2 f v -> SNMap22 k1 k2 f v\nsnMap22Insert (KV sn res) = I.insertWith (++) (hashStableName sn) [KV sn res]\n\n\n-- | The type of a memo table for functions of a.\ntype Memo a = forall r. (a -> r) -> (a -> r)\n\n-- | Memoize a two argument function (just apply the table directly for\n-- single argument functions).\nmemo2 :: Memo a -> Memo b -> (a -> b -> r) -> (a -> b -> r)\nmemo2 a b = a . (b .)\n\n-- | Memoize a three argument function.\nmemo3 :: Memo a -> Memo b -> Memo c -> (a -> b -> c -> r) -> (a -> b -> c -> r)\nmemo3 a b c = a . (memo2 b c .)\n"
  },
  {
    "path": "TypedFlow/Memo2.hs",
    "content": "{-# LANGUAGE GeneralizedNewtypeDeriving #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE LambdaCase #-}\n{-# LANGUAGE RecordWildCards #-}\n{-# LANGUAGE PolyKinds #-}\n{-# LANGUAGE KindSignatures #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE GADTs #-}\n\nmodule TypedFlow.Memo2 where\n\nimport Data.Kind (Type)\nimport qualified Data.Map.Strict as M\nimport System.Mem.StableName\n-- import Data.IORef\n-- import System.IO.Unsafe\nimport Unsafe.Coerce\nimport qualified Data.IntMap as I\nimport Data.Type.Equality\nimport Control.Monad.IO.Class\nimport Data.IORef\nimport TypedFlow.Types.Proofs (SingEq(..))\nimport Data.List (intercalate)\n\ndata Map0 k (m :: Type -> Type) f v = forall . Map0 {\n  m0Key :: f -> IO k,\n  m0Empty :: m v,\n  m0lk  :: k -> m v -> Maybe v,\n  m0upd :: k -> (Maybe v -> v) -> m v -> m v,\n  m0fmap :: forall u w.  (u -> w) -> m u -> m w,\n  m0showKey :: k -> String,\n  m0showTbl :: (v -> String) -> (m v -> String)\n  }\n\n\ndata Map1 (k :: k1 -> Type) (m :: (k1 -> Type) -> Type)  (f :: k1 -> Type) (v :: k1 -> Type) = Map1 {\n  m1Key :: forall x. f x -> IO (k x),\n  m1Empty :: m v,\n  m1lk  :: forall x. k x -> m v -> Maybe (v x),\n  m1upd :: forall x. k x -> (Maybe (v x) -> (v x)) -> m v -> m v,\n  m1showKey :: forall x . k x -> String,\n  m1showTbl :: (forall x . v x -> String) -> (m v -> String)\n  }\n\ndata Map2 (k :: k1 -> k2 -> Type) (m :: (k1 -> k2 -> Type) -> Type)  (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) = Map2 {\n  m2Key :: forall x y. f x y -> IO (k x y),\n  m2Empty :: m v,\n  m2lk  :: forall x y. k x y -> m v -> Maybe (v x y),\n  m2upd :: forall x y. k x y -> (Maybe (v x y) -> (v x y)) -> m v -> m v,\n  -- m2fmap :: forall u w.  (forall x y. u x y -> w x y) -> m u -> m w,\n  m2showKey :: forall x y. k x y -> String,\n  m2showTbl :: (forall x y. 
v x y -> String) -> (m v -> String)\n  }\n\ndata Map3 (k :: k1 -> k2 -> k3 -> Type) (m :: (k1 -> k2 -> k3 -> Type) -> Type)  (f :: k1 -> k2 -> k3 -> Type) (v :: k1 -> k2 -> k3 -> Type) = Map3 {\n  m3Key :: forall x y z. f x y z -> IO (k x y z),\n  m3Empty :: m v,\n  m3lk  :: forall x y z. k x y z -> m v -> Maybe (v x y z),\n  m3upd :: forall x y z. k x y z -> (Maybe (v x y z) -> (v x y z)) -> m v -> m v,\n  m3showKey :: forall x y z. k x y z -> String,\n  m3showTbl :: (forall x y z. v x y z -> String) -> (m v -> String)\n  }\n\nnewtype Id x = Id x\n\nordMap :: forall k b. (Ord k, Show k) => Map0 k (M.Map k) k b\nordMap = Map0 {..} where\n  m0Key = return\n  m0Empty = mempty\n  m0lk k = M.lookup k\n  m0upd k f m = M.alter (Just . f) k m\n  m0fmap = fmap\n  m0showKey = show\n  m0showTbl sh m = intercalate \";\" [(show k) <> \"↦\" <> (sh v) | (k,v) <- M.assocs m]\n\ndata Single1 f g where\n  None1 :: Single1 f g\n  Single1 :: f a -> g a -> Single1 f g \n\nverifMap1 :: forall k v. SingEq k => Map1 k (Single1 k) k v\nverifMap1 = Map1 {..} where\n  m1Key = return\n  m1Empty = None1\n  m1lk :: k a -> Single1 k b -> Maybe (b a)\n  m1lk k = \\case\n    None1 -> Nothing\n    Single1 k' v -> case testEq k k' of\n      Just Refl -> Just v\n      Nothing -> error \"verifMap1: mismatching keys! (1)\"\n  m1upd :: forall x. k x -> (Maybe (v x) -> (v x)) -> Single1 k v -> Single1 k v\n  m1upd k f None1 = Single1 k (f Nothing)\n  m1upd k f (Single1 k' v) = case testEq k k' of\n      Just Refl -> Single1 k (f (Just v))\n      Nothing -> error \"verifMap1: mismatching keys! (2)\"\n  m1showKey _ = \"#\"\n  m1showTbl :: (forall x . v x -> String) -> (Single1 k v -> String)\n  m1showTbl _ None1 = \"·\"\n  m1showTbl h (Single1 _ v) = \"!\" <> (h v)\n\n\ntestStable :: StableName a -> StableName b -> Maybe (a :~: b)\ntestStable sn sn' | eqStableName sn sn' = Just (unsafeCoerce Refl)\n                  | otherwise = Nothing\n\nsnMap2 :: forall f v. 
Map2 (SN2 f) (SNMap22 f) f v\nsnMap2 = Map2 {..} where\n  m2showTbl :: (forall x y. v x y -> String) -> (SNMap22 f v -> String)\n  m2showTbl h (SNMap22 m) = intercalate \",\" [ m2showKey k <> \"↦\" <> h v | e <- I.elems m, KV k v <- e   ]\n  m2showKey (SN2 sn) = show (hashStableName sn)\n  m2Key obj = SN2 <$> makeStableName obj\n  m2Empty = mempty\n  m2lk = snMap22Lookup\n  m2upd :: SN2 f x y -> (Maybe (v x y) -> (v x y)) -> SNMap22 f v -> SNMap22 f v\n  m2upd (SN2 sn) f (SNMap22 m) = SNMap22 $\n                                 I.alter (\\case Nothing -> Just [KV (SN2 sn) (f Nothing)]\n                                                Just p -> Just (updKV (SN2 sn) f p))\n                                 (hashStableName sn)\n                                 m\n\n  updKV :: SN2 f' x y -> (Maybe (v' x y) -> (v' x y)) -> [KV k1 k2 (SN2 f') v'] -> [KV k1 k2 (SN2 f') v']\n  updKV (SN2 sn) f [] = [KV (SN2 sn) (f Nothing)]\n  updKV (SN2 sn) f (v@(KV (SN2 sn') x):xs) = case testStable sn sn' of\n    Just Refl -> KV (SN2 sn') (f (Just x)):xs\n    Nothing -> v : updKV (SN2 sn) f xs\n                                 \n  -- m2fmap :: forall u w.  (forall x y. u x y -> w x y) -> SNMap22 f u -> SNMap22 f w\n  -- m2fmap h (SNMap22 t) = SNMap22 (fmap (fmap (\\(KV k v) -> KV k (h v))) t)\n\n  snMap22Lookup :: forall a b f' v'. SN2 f' a b -> SNMap22 f' v' -> Maybe (v' a b)\n  snMap22Lookup (SN2 sn) (SNMap22 m) = do\n    x <- I.lookup (hashStableName sn) m\n    lkKV sn x\n\n  lkKV :: forall k1 k2 f' v' a b . StableName (f' a b) -> [KV k1 k2 (SN2 f') v'] -> Maybe (v' a b)\n  lkKV _ [] = Nothing\n  lkKV sn (KV (SN2 sn') v:kvs) = case testStable sn sn' of\n                             Just Refl ->  Just (unsafeCoerce v) -- sn == sn' -> a == a' and b == b' \n                             Nothing ->  lkKV sn kvs\n\n\ndata KV k1 k2 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type)  where\n  KV :: forall k1 k2 f v a b. 
f a b -> v a b -> KV k1 k2 f v\n\nnewtype SNMap22  (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) = SNMap22 (I.IntMap [KV k1 k2 (SN2 f) v]) deriving (Monoid, Semigroup)\n\nnewtype SN2 (f :: k1 -> k2 -> Type) a b = SN2 (StableName (f a b)) \n\ndata (:.:) (m1 :: k2 -> Type) (m2 :: k1 -> k2) (h :: k1) = Comp (m1 (m2 h))\n\n\ndata Sig02 f g x y where\n  Ex02 :: f -> g x y -> Sig02 f g x y\n\ndata Sig03 f g x y z where\n  Ex03 :: f -> g x y z -> Sig03 f g x y z\n\ndata Sig12 f g x y z where\n  Ex12 :: f x -> g y z -> Sig12 f g x y z\n\ndata Sig22 f g x y where\n  Ex22 :: f x y -> g x y -> Sig22 f g x y\n\ndata P33 f g x y z where\n  T33 :: f x y z -> g x y z -> P33 f g x y z\n\n\n\ncontaining00 :: (forall v. Map0 k1 m1 f v) -> Map0 k2 m2 g h -> Map0 (k1,k2) (m1 :.: m2)  (f,g) h\ncontaining00 f g  = Map0\n   {\n   m0Key = (\\(a,b) -> (,) <$> m0Key f a <*> m0Key g b),\n   m0Empty = Comp (m0Empty f),\n   m0lk = \\(k1,k2) (Comp t) -> do t' <- m0lk f k1 t; m0lk g k2 t',\n   m0upd = \\(k1,k2) h (Comp t) -> Comp $ m0upd f k1 (m0upd g k2 h . \\case Just tb -> tb; Nothing -> (m0Empty g)) t,\n   m0fmap = \\h (Comp t) -> Comp $ m0fmap f (m0fmap g h) t,\n   m0showKey = \\(k1,k0) -> m0showKey f k1 <> \",\" <> m0showKey g k0,\n   m0showTbl = \\h (Comp t) -> m0showTbl f (m0showTbl g h) t\n   }                      \n\ncontaining02 :: (forall v. Map0 k1 m1 f v) -> Map2 k2 m2 g h -> Map2 (Sig02 k1 k2) (m1 :.: m2) (Sig02 f g)  h\ncontaining02 f g = Map2\n   {\n   m2Key = (\\(Ex02 a b) -> Ex02 <$> m0Key f a <*> m2Key g b),\n   m2Empty = Comp (m0Empty f),\n   m2lk = \\(Ex02 k1 k2) (Comp t) -> do t' <- m0lk f k1 t; m2lk g k2 t',\n   m2upd = \\(Ex02 k1 k2) h (Comp t) -> Comp $ m0upd f k1 (m2upd g k2 h . 
\\case Just tb -> tb; Nothing -> (m2Empty g)) t,\n   -- m2fmap = \\h (Comp t) -> Comp $ m0fmap f (m2fmap g h) t,\n   m2showKey = \\(Ex02 k1 k2) -> m0showKey f k1 <> \",\" <> m2showKey g k2,\n   m2showTbl = \\h (Comp t) -> m0showTbl f (m2showTbl g h) t\n   }                      \n\ncontaining03 :: (forall v. Map0 k1 m1 f v) -> Map3 k2 m2 g h -> Map3 (Sig03 k1 k2) (m1 :.: m2) (Sig03 f g)  h\ncontaining03 f g = Map3\n   {\n   m3Key = (\\(Ex03 a b) -> Ex03 <$> m0Key f a <*> m3Key g b),\n   m3Empty = Comp (m0Empty f),\n   m3lk = \\(Ex03 k1 k3) (Comp t) -> do t' <- m0lk f k1 t; m3lk g k3 t',\n   m3upd = \\(Ex03 k1 k3) h (Comp t) -> Comp $ m0upd f k1 (m3upd g k3 h . \\case Just tb -> tb; Nothing -> (m3Empty g)) t,\n   m3showKey = \\(Ex03 k1 k2) -> m0showKey f k1 <> \",\" <> m3showKey g k2\n,\n   m3showTbl = \\h (Comp t) -> m0showTbl f (m3showTbl g h) t\n  }                      \n\nnewtype Lam' (m2 :: (k2 -> k3 -> Type) -> Type) (h :: k1 -> k2 -> k3 -> Type) (a :: k1) = Lam' {fromLam' :: (m2 (h a))}\ndata M12 (m1 :: (k1 -> Type) -> Type) (m2 :: (k2 -> k3 -> Type) -> Type) (h :: k1 -> k2 -> k3 -> Type) = M12 (m1 (Lam' m2 h))\n\ncontaining12 :: (forall v. Map1 k1 m1 f v) -> (forall k4. Map2 k2 m2 g (h k4)) -> Map3 (Sig12 k1 k2) (M12 m1 m2) (Sig12 f g) h\ncontaining12 f g = Map3\n   {\n   m3Key = (\\(Ex12 a b) -> Ex12 <$> m1Key f a <*> m2Key g b),\n   m3Empty = M12 (m1Empty f),\n   m3lk = \\(Ex12 k1 k2) (M12 t) -> do Lam' t' <- m1lk f k1 t; m2lk g k2 t',\n   m3upd = \\(Ex12 k1 k2) h (M12 t) -> M12 $ m1upd f k1 (Lam' . m2upd g k2 h . (\\case Just tb -> tb; Nothing -> m2Empty g) . fmap fromLam') t,\n   m3showKey = \\(Ex12 k1 k2) -> m1showKey f k1 <> \">\" <> m2showKey g k2,\n   m3showTbl = \\h (M12 t) -> m1showTbl f (m2showTbl g h . fromLam') t\n   }\n\n\n\ndata F2m m g h = F2m (forall x y. g x y -> m (h x y))\ndata F2m' m g f h = F2m' (forall x y. g x y -> f x y -> m (h x y))\n\ndata F3m m g h = F3m (forall x y z. 
g x y z -> m (h x y z))\ndata F3m' m g f h = F3m' (forall x y z. g x y z -> f x y z -> m (h x y z))\n\nmemo2 :: forall g h k m n. MonadIO n => Map2 k m g h -> ((forall x y. g x y -> n (h x y)) -> forall x y. g x y -> n (h x y)) -> n (F2m n g h)\nmemo2 Map2{..} f = do\n    tblRef <- liftIO $ newIORef m2Empty\n    let finished :: forall x y. g x y -> n (h x y)\n        finished arg = do\n          tbl <- liftIO $ readIORef tblRef\n          key <- liftIO $ m2Key arg\n          case m2lk key tbl of\n            Just result -> do\n              -- liftIO $ putStrLn \"memo2: hit\"\n              return result\n            Nothing -> do\n              -- liftIO $ putStrLn \"memo2: miss\"\n              res <- f finished arg\n              liftIO $ modifyIORef tblRef (m2upd key $ \\_ -> res)\n              return res\n    return (F2m finished)\n  \nmemo2' :: forall g f h k m n. MonadIO n => Map2 k m g h -> ((forall x y. g x y -> f x y -> n (h x y)) -> forall x y. g x y -> f x y -> n (h x y)) -> n (F2m' n g f h)\nmemo2' Map2{..} f = do\n    tblRef <- liftIO $ newIORef m2Empty\n    let finished :: forall x y. g x y -> f x y -> n (h x y)\n        finished arg extra = do\n          tbl <- liftIO $ readIORef tblRef\n          key <- liftIO $ m2Key arg\n          case m2lk key tbl of\n            Just result -> do\n              -- liftIO $ putStrLn \"memo2': hit\"\n              return result\n            Nothing -> do\n              -- liftIO $ putStrLn (\"memo2: miss \" <> m2showKey key) --  <> \" from \" <> m3showTbl (const \".\") tbl\n              res <- f finished arg extra\n              liftIO $ modifyIORef tblRef (m2upd key $ \\_ -> res)\n              return res\n    return (F2m' finished)\n\nmemo3' :: forall g f h k m n. MonadIO n => Map3 k m g h -> ((forall x y z. g x y z -> f x y z -> n (h x y z)) -> forall x y z. 
g x y z -> f x y z -> n (h x y z)) -> n (F3m' n g f h)\nmemo3' Map3{..} f = do\n    tblRef <- liftIO $ newIORef m3Empty\n    let finished :: forall x y z. g x y z -> f x y z -> n (h x y z)\n        finished arg extra = do\n          tbl <- liftIO $ readIORef tblRef\n          key <- liftIO $ m3Key arg\n          case m3lk key tbl of\n            Just result -> do\n              -- liftIO $ putStrLn \"memo3: hit\"\n              return result\n            Nothing -> do\n              -- liftIO $ putStrLn (\"memo3: miss \" <> m3showKey key) --  <> \" from \" <> m3showTbl (const \".\") tbl\n              res <- f finished arg extra\n              liftIO $ modifyIORef tblRef (m3upd key $ \\_ -> res)\n              return res\n    return (F3m' finished)\n"
  },
  {
    "path": "TypedFlow/Models/Topic.hs",
    "content": "{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE MultiParamTypeClasses #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE RecordWildCards #-}\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeInType #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-|\nModule      : TypedFlow.Models.Topic\nDescription : Topic models\nCopyright   : (c) Jean-Philippe Bernardy, 2017\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n-}\n\n\nmodule TypedFlow.Models.Topic where\nimport Prelude hiding (RealFrac(..))\nimport TypedFlow.TF\nimport TypedFlow.Layers\nimport TypedFlow.Types\nimport TypedFlow.Types.Proofs ((?>), knownSum')\nimport TypedFlow.Learn\nimport GHC.TypeLits\nimport Data.Monoid ((<>))\nimport Data.Proxy\n\n-- | A convolutional document summary function. Described in\n-- 'Topically Driven Neural Language Model' by Lau, Baldwin and Cohn.\ntdlmDocsummary :: forall\n  (vocSize :: Nat) -- number of words\n  (e :: Nat) -- size of the embedding\n  (a :: Nat) -- number of features of the document vector summary \n  (n :: Nat) -- length of the document\n  (filterSize :: Nat) -- size of the convolution filter\n  (t :: NBits) -- size of floats\n  .  
KnownNat vocSize => KnownNat filterSize => KnownNat e => KnownNat a => KnownNat n => KnownBits t\n  => (EmbeddingP vocSize e (Flt t))\n  -> (ConvP (Flt t) a e '[filterSize])\n  -> DropProb\n  -> Gen (T '[n] Int32 -> T '[a] (Flt t))\ntdlmDocsummary embs filters dropProb = do\n  drpEmb <- mkDropout dropProb\n  return $ \\document ->\n    let embeddedDoc :: Tensor [n,e] (Flt t)\n        embeddedDoc = mapT (drpEmb . embedding @e @vocSize embs) document\n    in reduceMax axis0 (conv' @'[n] filters embeddedDoc)\n\ntdlmDocsummary' :: forall\n  (vocSize :: Nat) -- number of words\n  (e :: Nat) -- size of the embedding\n  (n :: Nat) -- length of the document\n  -- (a :: Nat) -- number of features of the document vector summary \n  -- (filterSize :: Nat) -- size of the convolution filter\n  spec\n  (t :: NBits) -- size of floats\n  proxy\n  .  KnownNat vocSize => KnownNat (Ap Frst' spec) => KnownNat e => KnownNat (Ap Scnd' spec) => KnownNat n => KnownBits t\n  => proxy spec\n  -> (EmbeddingP vocSize e (Flt t))\n  -> (ConvP (Flt t) (Ap Scnd' spec) e '[(Ap Frst' spec)])\n  -> DropProb\n  -> Gen (T '[n] Int32 -> T '[Ap Scnd' spec] (Flt t))\ntdlmDocsummary' _proxy  = tdlmDocsummary\n\nscnds :: SList xs -> SList (Ap (FMap Scnd') xs)\nscnds Unit = Unit\nscnds (_ :* xs) = Proxy :* scnds xs\n-- hmap _ Unit = Unit\n-- hmap f (x :* xs) = f x :* hmap f xs\n\nmkTdlmDocsummary :: forall\n  (vocSize :: Nat) -- number of words\n  (e :: Nat) -- size of the embedding\n  (spec :: [(Nat,Nat)]) -- (size of the convolution filter,number of features) \n  (n :: Nat) -- length of the document\n  (t :: NBits) -- size of floats\n  .  
KnownNat vocSize => KnownNat e => KnownNat n => KnownBits t\n  => All KnownNat (Ap (FMap Scnd') spec)\n  => All KnownNat (Ap (FMap Frst') spec)\n  => SList spec\n  -> DropProb\n  -> Gen (T '[n] Int32 -> T '[Sum (Ap (FMap Scnd') spec)] (Flt t))\nmkTdlmDocsummary xs0 dropProb = case xs0 of\n  Unit -> return (\\_ -> zeros)\n  (proxy :* xs) -> knownSum' (scnds xs) ?>\n                   do embs <- parameterDefault (\"embs_topic_\" ++ show (sListLength xs))\n                      filters <- parameterDefault (\"filters_topic_\" ++ show (sListLength xs))\n                      f <- tdlmDocsummary' @vocSize @e proxy embs filters dropProb\n                      fs <- mkTdlmDocsummary @vocSize @e xs dropProb\n                      return $ \\input -> concat0 (f input) (fs input)\n\n-- | Parameter for topics. This is effectively map from document\n-- features (a) to topic representations (vectors of size b) via k\n-- topic distributions.\ndata TopicP t a k b = TopicP {topicDistributions :: (T '[a,k] (Flt t))  -- ^ a linear map from documents features (a) to topic distributions (k)\n                             ,topicRepresentations :: (T '[k,b] (Flt t)) -- ^ a linear map from topic distributions (k) to topic representations (b)\n                             }\n\ninstance (KnownNat a, KnownNat k, KnownNat b, KnownBits t) => KnownTensors (TopicP t a k b) where\n  travTensor f s (TopicP x y) = TopicP <$> travTensor f (s<>\"_A\") x <*> travTensor f (s<>\"_B\") y\ninstance (KnownNat a, KnownNat k, KnownNat b, KnownBits t) => ParamWithDefault (TopicP t a k b) where\n  defaultInitializer = TopicP <$> glorotUniform <*> glorotUniform\n\n-- | A topic modeler. Described 'Topically Driven Neural Language\n-- Model' by Lau, Baldwin and Cohn.  Returns a function converting raw\n-- representations (eg. 
document summaries) to topic representations.\n-- This representation can be used as input to a dense layer to\n-- predict a word, or as input to an LSTM (initial state) to predict\n-- sentences.\nmkTdlmTopic :: forall\n  (kk :: Nat) -- number of topics\n  (a :: Nat) -- document vector summary size\n  (b :: Nat) -- topic representation size\n  (t :: NBits) -- size of floats\n  . KnownNat kk => KnownNat a => KnownNat b => KnownBits t\n  => Float -> TopicP t a kk b -> Gen (T '[a] (Flt t) -> (Tensor '[b] (Flt t), Tensor '[kk] (Flt t)))\nmkTdlmTopic separationConstant (TopicP topicInput topicOutput) = do\n  drpS   <- mkDropout (DropProb 0.1)\n  let topicNormalized :: T '[kk,b] (Flt t)\n      topicNormalized = mapT normalize topicOutput\n      -- matrix of correlation between the topics\n      topicCorrelation :: T '[kk,kk] (Flt t)\n      topicCorrelation = matmul topicNormalized (transpose01 topicNormalized)\n      -- max correlation between two distinct topics\n      topicOverlap = reduceMaxAll (square (topicCorrelation ⊝ eye))\n  addRegularizer (constant separationConstant ⊙ cast topicOverlap) -- regularizer which ensures that topics are disjoint\n\n  return (\\d -> let p :: T '[kk] (Flt t)\n                    p = softmax0 (topicInput ∙ d) -- attention distribution (among the topics)\n                in (drpS (topicOutput ∙ p), p))\n\n\n\n-- | Gating unit which can be used to mix an RNN hidden state with an\n-- external information source (eg. topic representation).  Described in\n-- 'Topically Driven Neural Language Model' by Lau, Baldwin and Cohn;\n-- formula (3)\ntdlmGatingUnit :: KnownNat n => KnownFloat t => KnownNat m => (GRUP t m n) -> T '[n] t -> T '[m] t -> (T '[m] t)\ntdlmGatingUnit (GRUP wz wr w) s h = \n  let x = concat0 h s\n      z = sigmoid (wz ∙ x)\n      r = sigmoid (wr ∙ x)\n      hTilda = tanh (w ∙ (concat0 (r ⊙ h) s))\n  in ((ones ⊝ z) ⊙ h + z ⊙ hTilda)\n"
  },
  {
    "path": "TypedFlow/Models/Transformer.hs",
    "content": "{-# LANGUAGE PartialTypeSignatures #-}\n{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE MultiParamTypeClasses #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE RecordWildCards #-}\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeInType #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-# LANGUAGE NoStarIsType #-}\n{-|\nModule      : TypedFlow.Models.Transformer\nDescription : Topic models\nCopyright   : (c) Jean-Philippe Bernardy, 2020\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n-}\n\n\nmodule TypedFlow.Models.Transformer where\nimport Prelude hiding (RealFrac(..))\nimport TypedFlow.TF\nimport TypedFlow.Abstract\nimport TypedFlow.Layers\nimport TypedFlow.Types\nimport TypedFlow.Types.Proofs ((?>), knownSum')\nimport GHC.TypeLits\n\n-- Convention for type variables:\n-- h = number of heads\n-- e = embedding size\n-- n = sequence length\n\naverage :: forall e. KnownNat e => T '[e] Float32 -> Scalar Float32\naverage = reduceMeanAll\n\n-- | Normalise a vector. But add a small epsilon to avoid division by zero\nnormalizer :: forall e. 
KnownNat e => T '[e] Float32 -> T '[e] Float32\nnormalizer x = mapT (⊘ (sigma + epsilon)) xmu -- so the norm of result is almost 1\n  where mu = average x\n        xmu = mapT (⊝ mu) x  -- so the average of xmu is 0\n        sigma = sqrt (average (square xmu)) -- the norm of xmu.\n        epsilon = 0.001 -- ?\n\n-- Informally:\n-- mapT f x = vector y such that y_i = f (x_i) -- (the first axis)\n\ndimAsFloat :: forall e. KnownNat e => Float\ndimAsFloat = fromIntegral (knownNatVal (natSat @e))\n\n-- | dot product attention on one key (k)\ndotAttention1 :: forall e n. KnownNat e => KnownNat n\n  => T '[e,n] Float32 -> T '[n,e] Float32 -> T '[e] Float32 -> T '[e] Float32\ndotAttention1 q v k = v ∙ softmax0 (mapT (⊘ normFactor) (q ∙ k))\n  where normFactor = constant (sqrt (dimAsFloat @e))\n\n-- | dot product attention for every position\ndotAttention :: forall n e. KnownNat n => KnownNat e\n  => T '[n,e] Float32 -> T '[n,e] Float32 -> T '[n,e] Float32 -> T '[n,e] Float32\ndotAttention v k q = mapT (dotAttention1 (transpose01 q) v) k\n\n-- | h copies of a dense layer (the same for every copy).\nmultiheadLinearEncoder :: forall h e. KnownNat e => KnownNat h =>\n  String -> Gen (T '[e] Float32 -> T '[h,e] Float32)\nmultiheadLinearEncoder name = do\n  wv <- parameterDefault (\"w_\" ++ name)\n  return $ \\x -> reshape (wv # x)\n\nmultiheadSelfAttentionModule\n  :: forall h n e. 
KnownNat n => KnownNat h => KnownNat e\n  => String -> Gen (T '[n,e] Float32 -> T '[n,e] Float32)\nmultiheadSelfAttentionModule nm = do\n  ev <- multiheadLinearEncoder @h (\"v\" ++ nm)\n  eq <- multiheadLinearEncoder @h (\"q\" ++ nm)\n  ek <- multiheadLinearEncoder @h (\"k\" ++ nm)\n  w1 <- parameterDefault (\"w1\" ++ nm)\n  -- w2 <- parameterDefault (\"w2\" ++ nm)\n  return $ \\x ->\n    let v = transpose01 (mapT ev x)\n        q = transpose01 (mapT eq x)\n        k = transpose01 (mapT ek x)\n        r :: T '[n,h,e] Float32\n        r = transpose01 (zipWith3T dotAttention q k v)\n        r' = mapT (dense @e w1 . reshape @'[h * e]) r\n    in mapT ({-dense w2 . -}normalizer) (r' + x)\n       -- x + mapT normalizer r'\n\nmultiheadSelfAttentionModuleDecoder\n  :: forall h n e. KnownNat n => KnownNat h => KnownNat e\n  => String -> Gen (T '[n,e] Float32 -> T '[n,e] Float32  -> T '[n,e] Float32)\nmultiheadSelfAttentionModuleDecoder nm = do\n  ev <- multiheadLinearEncoder @h (\"v\" ++ nm)\n  eq <- multiheadLinearEncoder @h (\"q\" ++ nm)\n  ek <- multiheadLinearEncoder @h (\"k\" ++ nm)\n  w1 <- parameterDefault (\"w1\" ++ nm)\n  -- w2 <- parameterDefault (\"w2\" ++ nm)\n  return $ \\x    -- comes from decoder\n            y    -- comes from encoder\n           ->\n    let k = transpose01 (mapT ek y)\n        v = transpose01 (mapT ev x)\n        q = transpose01 (mapT eq y)\n        r :: T '[n,h,e] Float32\n        r = transpose01 (zipWith3T dotAttention q k v)\n        r' = mapT (dense @e w1 . reshape @'[h * e]) r\n    in mapT ({-dense w2 . -}normalizer) (r' + x)\n       -- x + mapT normalizer r'\n\n\nfeedForwardModule :: forall e. KnownNat e\n  => String -> Gen (T '[e] Float32 -> T '[e] Float32)\nfeedForwardModule nm = do\n  w1 :: DenseP Float32 e e <- parameterDefault (nm ++ \"w1\")\n  w2 <- parameterDefault (nm ++ \"w2\")\n  return $ \\x -> normalizer (x + (w2 # relu (w1 # x)))\n\nencoderModule :: forall h n e. 
KnownNat n => KnownNat h => KnownNat e => DropProb \n  -> String -> T '[n,e] Float32 -> Gen (T '[n,e] Float32 -> T '[n,e] Float32)\nencoderModule dropProb nm positionalTensor = do\n  drp <- mkDropout dropProb\n  selfAtt <- multiheadSelfAttentionModule @h (nm ++ \"mh\")\n  ff <- feedForwardModule (nm ++ \"ff\")\n  return (mapT ff . selfAtt . (+ positionalTensor) . drp)\n\npositionalModuleSinCos :: forall n e. KnownNat e => KnownNat n => T '[n,e] Float32\npositionalModuleSinCos = sin (transpose01 (broadcastT pos) * (broadcastT omega) + broadcastT phase)\n  where pos = (cast (range @n @'B32)) :: T '[n] Float32\n        phase = cast ((range @e @'B32) `floorMod` constant 2) * (constant pi/2) :: T '[e] Float32\n        omega = constant (log 10000) * exp (constant (-2.0 / dimAsFloat @e) * cast (range @e @'B32))\n        -- Note I'm not dividing the frequency by 2 because integer\n        -- division isn't implemented. Should not have any consequence.\n\npositionalModuleLearned :: KnownNat e => KnownNat n => Gen (T '[n,e] Float32)\npositionalModuleLearned = do\n  e <- parameterDefault \"positional\"\n  return $ let EmbeddingP x = e in x\n\nencoderStack :: forall h n e. KnownNat h => KnownNat n => KnownNat e\n  => DropProb -> Int -> Gen (T '[n,e] Float32 -> T '[n,e] Float32)\nencoderStack dropProb n = do\n  p <- positionalModuleLearned\n  encoders <- mapM (\\i -> encoderModule @h dropProb (\"enc\" ++ show i) p) [1..n]\n  return (foldr (.) id encoders) -- n-ary function composition\n"
  },
  {
    "path": "TypedFlow/Python.hs",
    "content": "{-# LANGUAGE ViewPatterns #-}\n{-|\nModule      : TypedFlow.Python\nDescription : Python-generation Functions \nCopyright   : (c) Jean-Philippe Bernardy, 2017\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n\n-}\n\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE GeneralizedNewtypeDeriving #-}\n{-# LANGUAGE LambdaCase #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE MultiParamTypeClasses #-}\n{-# LANGUAGE OverloadedStrings #-}\n{-# LANGUAGE PatternSynonyms #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE RecordWildCards #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeInType #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UndecidableSuperClasses #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n\nmodule TypedFlow.Python (compile, compileGen, generateFile) where\n\nimport Data.Char (toLower)\nimport Data.Proxy\nimport Data.List (genericReplicate, )\nimport GHC.TypeLits\nimport Control.Monad.State\nimport TypedFlow.Types\nimport TypedFlow.Broadcast (permToFun,unopInputShape)\nimport TypedFlow.Types.Proofs\nimport TypedFlow.Memo\nimport Prettyprinter as PP\nimport Prettyprinter.Render.String as PP\nimport qualified Data.Map as M\nimport TypedFlow.Learn\nimport qualified Data.Sequence as S\nimport Data.Sequence (Seq, (|>), )\nimport Data.Foldable\n\nfirst :: (t -> a) -> (t, b) -> (a, b)\nfirst f (x,y) = (f x,y)\n\nparamShape' :: VarInfo -> [Integer]\nparamShape' (VarInfo {varRef = Ref _ s _}) = shapeToList' s\n\nparamDType ::  VarInfo -> Typ\nparamDType 
(VarInfo {varRef = Ref _ _ t}) = sTypTyp t\n\nparamName :: VarInfo -> String\nparamName (VarInfo {varRef = Ref {..}}) = refName\n\n\ngenerateFile :: String -> Python [VarInfo] -> IO ()\ngenerateFile fname g = do\n  putStrLn (\"Parameters (total \" ++ show (sum [product (paramShape' p) | p <- params]) ++ \"):\")\n  forM_ params printParam\n  writeFile fname output\n  where (output,params) = generate g\n        printParam p = putStrLn (paramName p ++ \": \" ++ \"T \" ++ renderSimple (showShape' (paramShape' p))  ++ \" \" ++ showT (paramDType p))\n\nnamed :: String -> DOC -> DOC\nnamed fname x = text (fname <> \"=\") <> x\n\ntext :: String -> DOC\ntext = pretty\n\ngenFun :: forall b. String -> [DOC] -> Python b -> Python b\ngenFun name args body = do\n  gen (text \"def \" <> text name <> align (tuple args) <> text \":\")\n  withDOC (\\b -> \"  \" <> align b) body\n\n\nshowTyp :: forall t. KnownTyp t => DOC\nshowTyp = text (showT (typVal @t))\n\nshowSTyp :: forall t. STyp t -> DOC\nshowSTyp t = knownTyp t $ showTyp @t\n\nshowT :: Typ -> [Char]\nshowT (Typ Bool _) = \"tf.bool\"\nshowT (Typ Cmplx B32) = \"tf.complex64\"\nshowT (Typ Cmplx B64) = \"tf.complex128\"\nshowT (Typ k l) = \"tf.\" ++ map toLower (show k) ++ drop 1 (show l)\n\nshowShape' ::  [Integer] -> DOC\nshowShape' s = list (map (showDim' \"None\") s)\n\nshowShape :: ∀ (s :: Shape). All KnownNat s => SList s -> DOC\nshowShape s = showShape' (shapeToList'' s)\n\nshowSShape :: ∀ (s :: Shape). SShape s -> DOC\nshowSShape s = showShape' (shapeToList' s)\n\nshowShapeType :: ∀ (s :: Shape). KnownShape s => DOC\nshowShapeType = showSShape (typeSShape @s)\n\n-- | Show a shape, but \"None\" is replaced by \"-1\"\nshowShapeMinus :: forall (s::Shape) proxy. All KnownNat s => SList' proxy s -> DOC\nshowShapeMinus s = list (map (showDim' \"-1\") (shapeToList'' s))\n\nshowShapeLen :: ∀ (s::Shape). KnownLen s => DOC\nshowShapeLen = (text . 
show) (listTypeLen @ s)\n\nshowDim' :: String -> Integer -> DOC\nshowDim' none n = text (if n == 514229 then none else show n)\n\nshowDimM :: forall n. KnownNat n => DOC\nshowDimM = showDim' \"-1\" (natVal (Proxy @ n))\n\nshowDim :: forall n. KnownNat n => DOC\nshowDim = showDim' \"None\" (natVal (Proxy @ n))\n\nshowDimS :: forall n. Sat KnownNat n -> DOC\nshowDimS Sat = showDim @n\n\ngen :: DOC -> Python ()\ngen s = modify $ \\PyState{..} -> PyState {genText=genText |> s,..}\n\nsetGen :: Seq DOC -> Python ()\nsetGen d = modify $ \\PyState{..} -> PyState {genText=d,..}\n\n(<--) :: Ref Int s t -> UntypedExpression -> Python ()\nx <-- y = gen (pyVarRepr x <> text \"=\" <>  y)\n\n\nrenderSimple :: Doc ann -> String\nrenderSimple = renderString . layoutPretty (LayoutOptions Unbounded)\n\n-- | save an intermediate result to a variable and save it to\n-- genAssignTable for future re-use.\ncache :: forall s t. KnownTyp t => KnownShape s => DOC  -> Python DOC\ncache x = do\n  let x' = renderSimple x\n  mcache <- M.lookup x' <$> gets genAssignTable\n  case mcache of\n    Just y -> do\n      -- comment (\"cache hit: \" <> text x')\n      return y\n    Nothing -> do\n      -- comment (\"cache miss\")\n      v <- newPyVar @s @t\n      comment (\"shape: \" <> (showShapeType @s))\n      v <-- x\n      modify $ (\\g -> g {genAssignTable = M.insert x' (pyVarRepr v) (genAssignTable g)})\n      return (pyVarRepr v)\n\nnewPyVar' :: forall s t. SShape s -> STyp t -> Python (Ref Int s t)\nnewPyVar' s t = knownSShape s ?> (knownTyp t $ newPyVar @s @t)\n\nnewId :: Python Integer\nnewId = do\n  n <- gets genId\n  modify $ \\PyState{..} -> PyState {genId=genId+1,..}\n  return n\n\nnewPyVar :: forall s t. 
KnownShape s => KnownTyp t => Python (Ref Int s t)\nnewPyVar = do\n  n <- newId\n  return $ Ref (fromIntegral n) typeSShape typeSTyp\n\npyVarInfoRepr :: VarInfo -> DOC\npyVarInfoRepr i = text (varName i)\n\npyVarRepr :: Ref Int s t -> DOC\npyVarRepr (Ref n _ _) = text (\"var\" <> show n)\n\ntuple :: [DOC] -> DOC\ntuple = parens . align . sep . punctuate comma\ndict :: [(String,DOC)] -> DOC\ndict xs = braces $ align $ sep $ punctuate comma [text (show k) <> \":\" <> v | (k,v) <- xs]\n\nfuncall :: String -> [DOC] -> DOC\nfuncall = funcall' . text\n\nfuncall' :: DOC -> [DOC] -> DOC\nfuncall' f args =  f <> tuple args\n\ncomment :: DOC -> Python ()\ncomment c = gen (\"#\" <> c)\n\nfunc :: String -> [DOC] -> [(String,DOC)] -> DOC\nfunc fname positional namedArgs = funcall fname (positional ++ map (uncurry named) namedArgs )\n\nwithDOC :: forall a. (DOC -> DOC) -> Python a -> Python a\nwithDOC f g = do\n  before <- gets genText\n  setGen mempty\n  x <- g\n  after <- gets genText\n  setGen (before |> f (vcat $ toList after))\n  return x\n\ngenerate :: Python [VarInfo] -> (String,[VarInfo])\ngenerate s = (renderString (layoutPretty (LayoutOptions (AvailablePerLine 92 1)) (vcat $ toList genText)),\n              genPyVars)\n  where (genPyVars,PyState{..}) = runState s initPyState\n        initPyState = PyState {genPureTable = mempty\n                              ,genAssignTable = mempty\n                              ,genText = mempty\n                              ,genId = 10000}\n\ngeneratePure :: forall s t. 
KnownTyp t => KnownShape s => T s t -> Python DOC\ngeneratePure x = do\n  let sn = makeSn2 x\n  mv <- snMapLookup2 sn <$> gets genPureTable\n  case mv of\n    Just v -> do\n        -- comment (\"gp hit:\" <> v)\n        return v\n    Nothing -> do\n      -- comment (\"gp miss\")\n      e <- generatePure' (\\s x' -> knownSShape s ?> generatePure x') typeSShape x\n      v <- cache @s @t e\n      modify (\\g -> g {genPureTable = (snMapInsert2 sn v) (genPureTable g)})\n      return v\n\ngenDistr :: forall s s0 t. KnownTyp t => Distribution s t -> SShape s0 -> SShape s -> DOC\ngenDistr d sh s1 = case d of\n  TruncatedNormalD stddev -> funcall \"tf.random.truncated_normal\"\n    [showSShape (sh .+. s1), named \"stddev\" (float stddev), named \"dtype\" (showTyp @t)]\n  UniformD low high -> funcall \"tf.random.uniform\" [showSShape (sh .+. s1)\n                                ,named \"minval\" (float low)\n                                ,named \"maxval\" (float high)\n                                ,named \"dtype\" (showTyp @t)]\n  OrthogonalD ->\n    funcall' (funcall \"tf.keras.initializers.orthogonal\" []) [named \"dtype\" (showTyp @t), named \"shape\" (showSShape (sh .+. s1))]\n\ngeneratePure' :: forall s t. KnownTyp t => (forall s' t'. 
KnownTyp t' => SShape s' -> T s' t' -> Python DOC) -> SShape s -> T s t -> Python DOC\ngeneratePure' rec sR = knownSShape sR ?> \\case\n  Unbroadcast{} -> error \"broadcasting operation did not complete (Unbroadcast)!\"\n  BroadcastT _ _ _ sh x -> --- error \"broadcasting operation did not complete (BroadcastT)!\"\n    do\n     -- debug help\n     rx <- rec sh x\n     return (funcall \"ERROR:BroadcastT\" [rx])\n  MapT {} -> error \"broadcasting operation did not complete (mapT)!\"\n  ZipT {} -> error \"broadcasting operation did not complete (ZipT)!\"\n  Zip3T {} -> error \"broadcasting operation did not complete (Zip3T)!\"\n  If c x y -> do\n    rc <- rec typeSShape c\n    rx <- rec typeSShape x\n    ry <- rec typeSShape y\n    return (func \"tf.cond\" [rc] [(\"true_fn\", lambda0 rx) ,(\"false_fn\", lambda0 ry)])\n    where lambda0 z = text \"lambda: \" <> z\n  -- if broadcast_to is broken: https://github.com/tensorflow/tensorflow/issues/21901\n  -- DirectBroadcast s0 s1 s2 s3 x -> do\n  --  recx <- rec (s0 .+. s2) x\n  --  let expanded = func \"tf.reshape\" [recx,list (map (showDim' \"-1\")\n  --         (concat [shapeToList' s0, genericReplicate (sListLength s1) 1\n  --                 ,shapeToList' s2, genericReplicate (sListLength s3) 1 ]))] []\n  --  return (funcall \"tf.add\" [expanded, func \"tf.zeros\" [showSShape sR] [(\"dtype\", showTyp @t)]])\n  DirectBroadcast s0 s1 s2 s3 x -> do\n   recx <- rec (s0 .+. 
s2) x\n   let expanded = func \"tf.reshape\" [recx,list (map (showDim' \"-1\")\n          (concat [shapeToList' s0, genericReplicate (sListLength s1) 1\n                  ,shapeToList' s2, genericReplicate (sListLength s3) 1 ]))] []\n   return (funcall \"tf.broadcast_to\" [expanded, showSShape sR])\n  Noise noiseId s0 s1 x -> do\n    return $ (genDistr x s0 s1) <+> (text \"# \" <> integer noiseId)\n  T op -> return $ case op of\n    ExternalVar (Ref v _ _) -> text v\n    Variable v -> pyVarRepr v\n    (Constant c) -> funcall \"tf.constant\" [prettyT @t c, named \"shape\" (showSShape sR), named \"dtype\" (showTyp @t)]\n    (Range n@Sat) -> (func \"tf.range\" [] [(\"start\",integer 0),\n                               (\"limit\",integer (natVal n)),\n                               (\"dtype\",showTyp @t)])\n  Where c x y -> do\n    rc <- rec typeSShape c\n    rx <- rec typeSShape x\n    ry <- rec typeSShape y\n    return (funcall \"tf.where\" [rc, rx, ry])\n  UnOp operation s0  x -> do\n   recx <- rec (s0 .+. 
unopInputShape operation) x\n   return $ case operation of\n    Diag _ -> funcall \"tf.matrix_diag\" [recx]\n    Cast -> funcall \"tf.cast\" [recx,showTyp @t]\n    StopGradient -> funcall \"tf.stop_gradient\" [recx]\n    ExpM _  -> funcall \"tf.linalg.expm\" [recx]\n    ZeroTriangle _ side k  -> funcall (\"tf.experemental.numpy.tri\" ++ case side of Upper -> \"u\"; Lower -> \"l\") [recx, integer k]\n    Conjugate -> funcall \"tf.math.conj\" [recx]\n    RealPart -> funcall \"tf.math.real\" [recx]\n    Axis1Op _ (SliceOp _ _ lo hi) -> recx <> list (replicate (fromIntegral (sListLength s0)) (text \":\") ++ [integer lo <> text \":\" <> integer hi])\n    Axis1Op _ (AccessOp _ idx) -> recx <> list (replicate (fromIntegral (sListLength s0)) (text \":\") ++ [integer idx])\n    Axis1Op _ op' ->\n       let (op,args) = case op' of\n                         SliceOp {} -> error \"Python: panic: sliceop is special\"\n                         AccessOp {} -> error \"Python: panic: accessop is special\"\n                         ReverseT _ -> (\"tf.reverse\",[])\n                         OneHot depth -> (\"tf.one_hot\",[(\"dtype\",showTyp @t), (\"depth\", showDimS depth)])\n                         ArgMax{} -> (\"tf.argmax\",[(\"output_type\",showTyp @t)])\n                         ReduceOp _ r -> (\"tf.reduce_\" ++ rop, [])\n                            where rop = case r of\n                                           Max -> \"max\"\n                                           Min -> \"min\"\n                                           Sum -> \"sum\"\n                                           Mean -> \"mean\"\n           axisName = if op == \"tf.nn.softmax\" then \"dim\" else \"axis\"  -- use dim before TF 1.5\n           useAxisList = case op' of ReverseT _ -> True; _ -> False\n       in func op [recx] ((axisName,(if useAxisList then (list . 
(:[])) else id) (integer (sListLength s0))):args)\n    Float1Op op' -> funcall op (recx:args)\n       where (op,args) = case op' of\n                HardSigmoid -> (\"tf.keras.backend.hard_sigmoid\",[])\n                Relu -> (\"tf.nn.relu\",[])\n                ClipByValue lo hi -> (\"tf.clip_by_value\",[float lo,float hi])\n                _ -> (\"tf.\" ++ map toLower (show op'), [])\n    Num1Op op' -> funcall op (recx:args)\n       where (op,args) = case op' of\n                Negate -> (\"tf.negative\",[])\n                _ -> (\"tf.\" ++ map toLower (show op'), [])\n  MatMul s0 a b c x y  -> do\n    recx <- rec (s0 .+. (:*) a ((:*) b Unit)) x\n    recy <- rec (s0 .+. (:*) b ((:*) c Unit)) y\n    return (funcall \"tf.matmul\" [recx, recy])\n  BinOp operation s0 s1 _ s2 _ x y -> do\n   recx <- rec (s0 .+. s1) x\n   recy <- rec (s0 .+. s2) y\n   return $ case operation of\n     Simple2Op sop  -> let pop = case sop of\n                                   MkComplex -> \"tf.complex\"\n                                   Add -> \"tf.add\"\n                                   Divide -> \"tf.divide\"\n                                   IntegerDiv -> \"tf.math.floordiv\"\n                                   Equal -> \"tf.equal\"\n                                   Subtract -> \"tf.subtract\"\n                                   Multiply -> \"tf.multiply\"\n                                   Minimum -> \"tf.minimum\"\n                                   Maximum -> \"tf.maximum\"\n                                   Comparision op -> \"tf.math.\" ++ case op of\n                                     Less -> \"less\"\n                                     Greater -> \"greater\"\n                                     LessOrEqual -> \"less_equal\"\n                                     GreaterOrEqual -> \"greater_equal\"\n                                   Logic op -> \"tf.math.logical_\" ++ case op of\n                                      And -> \"and\"\n                           
           Or -> \"or\"\n                                   FloorMod -> \"tf.math.floorMod\"\n                       in funcall pop [recx,recy]\n     SigmoidCrossEntropyWithLogits -> func \"tf.nn.sigmoid_cross_entropy_with_logits\" [] [(\"labels\",recx),(\"logits\",recy)]\n     SparseSoftmaxCrossEntropyWithLogits -> func \"tf.nn.sparse_softmax_cross_entropy_with_logits\" []  [(\"labels\",recx),(\"logits\",recy)]\n     SoftmaxCrossEntropyWithLogits -> func \"tf.nn.softmax_cross_entropy_with_logits\" []   [(\"labels\",recx),(\"logits\",recy)] -- FIXME: use _v2 for TF 1.5\n  ReshapeFrom s t -> do\n    rt <- rec s t\n    return (funcall \"tf.reshape\" [rt, showShapeMinus sR])\n  Concat s0 s1 xs -> do\n    let go :: forall s0 s1 ns. SShape s0 -> SShape s1 -> NP (Catable s0 s1 t) ns -> Python [DOC]\n        go _ _ Unit = return []\n        go s0' s1' (Catable n y :* ys) = (:) <$> rec (s0' .+. n :* s1') y <*> go s0' s1' ys\n    rxs <- go s0 s1 xs\n    return (funcall \"tf.concat\" [list rxs, text \"axis=\" <> integer (sListLength s0)])\n  Transpose s p x -> do\n    rx <- rec s x\n    comment (\"transpose: p = \" <> text (show p) <> \"; \" <> text (show s))\n    return (func \"tf.transpose\" [rx] [(\"perm\",list (map (integer . permToFun p) [0.. sListLength s-1]))])\n  Gather indexShape s0 m s1 x ix -> do\n    rx <- rec (s0 .+. ((:*) m s1)) x\n    rix <- rec (s0 .+. indexShape) ix\n    return (func \"tf.gather\" [named \"params\" rx, named \"indices\" rix, named \"batch_dims\" (integer (sListLength s0)), named \"axis\" (integer (sListLength s0))] [])\n  GatherND containerShape elementShape indexShape x ix -> do\n    rx <- rec (containerShape .+. elementShape) x\n    rix <- rec (indexShape *: (sListLenAsNat containerShape)) ix\n    return (func \"tf.gather_nd\" [rx, rix] [])\n  Convolution bs inChans outChans filterShape s0 x filters -> do\n    recx <- rec ((:*) bs (s0 *: inChans)) x\n    recFilters <- rec (filterShape .+. 
((:*) inChans ((:*) outChans Unit))) filters\n    return (func \"tf.nn.convolution\" [recx, recFilters] [(\"padding\",text (show (\"SAME\"::String))),(\"data_format\", text (show dataFormat))])\n   where dataFormat = case sListLength filterShape of\n           1 -> (\"NWC\" :: String)\n           2 -> \"NHWC\"\n           3 -> \"NDHWC\"\n           _ -> error \"convolution: more than 3 spatial dimensions are not supported!\"\n  Pool bs window typ numChans outSpatial x -> do\n     rx <- rec ((:*) bs (zipWithMulSShapes window outSpatial .+. (:*) numChans Unit)) x\n     return (func \"tf.nn.pool\"\n                  [rx, showSShape window, typ']\n                  [(\"strides\", showSShape window),\n                   (\"padding\",text (show (\"SAME\" :: String)))])\n   where typ' = text $ (show $ case typ of MaxPool -> \"MAX\"; AvgPool -> \"AVG\" :: String)\n  Softmax _ _ x -> do\n     rx <- rec typeSShape x\n     return $ func \"tf.nn.softmax\" [rx] [(\"axis\",\"1\")]\n  -- _ -> error \"Python compiler: case not covered\"\ntype Python a = State PyState a\n\ngenerateParameters :: [VarInfo] -> Python [DOC]\ngenerateParameters genVars = do\n  -- generate variables\n  forM genVars $ \\v -> case v of\n      VarInfo {..} -> case varRef of\n        Ref refId shap typ -> do\n          ii <- case varInitial of\n            Nothing -> return []\n            Just iii -> do\n              iiii <- case knownSShape shap of\n                Sat -> knownTyp typ $ generatePure iii\n              return [named \"initial_value\" iiii]\n          var <- newPyVar' shap typ\n          var <-- funcall \"tf.Variable\" ([named \"name\" (string refId), named \"trainable\" (bool varTrainable)] ++ ii)\n          return (pyVarRepr var)\n\n-- | Clip a gradient\nclipByGlobalNorm :: Float -> UntypedExpression -> UntypedExpression\nclipByGlobalNorm maxNorm x = funcall \"tf.clip_by_global_norm\" [x,float maxNorm] <> brackets (int 0)\n -- clip_by_global_norm returns a couple (clipped grads, 
global_norm)\n\n-- | Gradient of y wrt. the given parameters.\ngrad :: UntypedExpression -> UntypedExpression -> UntypedExpression\ngrad y vars = funcall \"tf.gradients\" [y, vars]\n\n\nfnToPython ::[VarInfo] -> PreparedFunction -> Python ()\nfnToPython params PreparedFunction{pfInputs = SomeSuch placeHolders,\n                                   pfOutputs = SomeSuch returned,..} = do \n  -- we can't re-use intermediate computations from initialisers or other functions:\n  modify $ \\PyState {..} -> PyState {genPureTable = mempty, genAssignTable = M.empty,..}\n  gen (text \"@tf.function\")\n  genFun (pfName <> \"_fn\") (text \"training_placeholder\":\n                  map pyVarInfoRepr params ++\n                  hMapToList @KnownPlaceholder (text . placeholderName) placeHolders) $\n    do returns <- hfor @KnownPlaceholder returned $ \\ph@(PHT x) -> do\n         r <- generatePure x\n         return (placeholderName ph,r)\n       gen (text \"return \" <> dict returns)\n       return ()\n  gen (text pfName <> \" = \" <>\n        dict [\n          (\"function\",text pfName <> \"_fn\"),\n          (\"batched\",bool pfBatched),\n          (\"placeholders\",dict (hMapToList @KnownPlaceholder\n        (\\ph -> case placeHolderRef ph of\n                  Ref nm shape typ ->\n                    (nm, dict [(\"shape\",showSShape shape), (\"dtype\",showSTyp typ)]))\n        placeHolders))])\n  return ()\n  \ntoPython :: PreparedModel -> Python ()\ntoPython PreparedModel {..} = do\n  gen (text \"import tensorflow as tf\")\n  -- Static stuff: construct and initialise parameters, list placeholders, etc.\n  genFun \"mkModel\" [] $ do\n    vs <- generateParameters pmParams\n    gen (text \"return \" <>\n         dict [(\"batch_size\", integer pmBatchSize)\n              ,(\"parameters\",list vs)\n              ,(\"paramsdict\",dict [(varName p, v) | (p,v) <- zip pmParams vs])])\n  -- Loss/Accur/Predict function\n  forM_ pmFunctions (fnToPython pmParams)\n  return ()\n\n-- | Batchify 
and compile a model with simple input to output mapping.\ncompile :: forall batchSize sx tx sy ty sy_ ty_ p\n        .  (KnownNat batchSize, KnownShape sx, KnownTyp tx, KnownShape sy, KnownTyp ty, KnownShape sy_, KnownTyp ty_, KnownShape p, KnownLen p)\n        => Options\n        -> Gen (Tensor sx tx -> Tensor sy ty -> ModelOutput  ty_ p sy_)\n        -> Python [VarInfo]\ncompile options fGen = knownSShape (typeSShape @sy_ .+. typeSShape @p) ?> compileGen @batchSize options (sequenceA [simpleModel @p <$> fGen])\n\n-- | Batchify and compile a model with generic  input to output mapping and states\ncompileGen :: forall bs. (KnownNat bs)\n           => Options\n           -> Gen [Function]\n           -> Python [VarInfo]\ncompileGen options model = toPython pm >> return pmParams\n  where pm@PreparedModel{..} = prepare @bs model\n\n\n\nprettyT :: forall t. KnownTyp t => HaskType t -> DOC\nprettyT = case kindVal @(TypKind t) of\n  SInt -> case bitsVal @(TypBits t) of\n    SB32 -> int . fromIntegral\n    SB64 -> int . 
fromIntegral\n  SBool -> bool\n  SFloat -> case bitsVal @(TypBits t) of\n    SB32 -> float\n    SB64 -> double\n\n\n\ndata PyState = PyState {genId :: Integer\n                       ,genText :: S.Seq DOC\n                       ,genPureTable :: SSNMap2 Shape Typ T DOC\n                       -- ^ Table mapping pointers to their\n                       -- interpretations, so that sharing in the data\n                       -- structures can be exploited when generating\n                       ,genAssignTable :: M.Map String DOC\n                       -- ^ Table mapping expressions to variables, so\n                       -- that lost sharing can be recovered\n                       -- genPeeks :: [(String,UntypedExpression)]\n                       }\n\ntype UntypedExpression = DOC\ntype DOC = Doc ()\n\ndouble :: Double -> DOC\ndouble = pretty\nfloat :: Float -> DOC\nfloat = pretty\ninteger :: Integer -> DOC\ninteger = pretty\nint :: Int -> DOC\nint = pretty\nbool :: Bool -> DOC\nbool = pretty\nstring :: String -> DOC\nstring = dquotes . text \n"
  },
  {
    "path": "TypedFlow/TF.hs",
    "content": "{-# LANGUAGE InstanceSigs #-}\n{-|\nModule      : TypedFlow.TF\nDescription : Binding to tensorflow functions\nCopyright   : (c) Jean-Philippe Bernardy, 2017\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n\nThis module provides direct access to the most commonly used\nTensorFlow functions. Higher-level functions are not defined here.\n-}\n\n{-# LANGUAGE LambdaCase #-}\n{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE MultiParamTypeClasses #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE RecordWildCards #-}\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeInType #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-# LANGUAGE ApplicativeDo #-}\n{-# LANGUAGE NoStarIsType #-}\n\nmodule TypedFlow.TF (\n  -- * Variables, Parameters\n  -- ** Parameters\n  parameter',\n  parameter,\n  parameterDefault,\n  ParamWithDefault(..),\n  -- getParameters,\n  -- ** Persistent variables\n  persistent,\n  modifyPersistent,\n  -- ** Placeholders and outputs\n  -- placeholder,\n  -- peekAt,\n  -- peekAtMany,\n  -- * Operations\n  -- ** Constants\n  zeros,\n  ones,\n  eye,\n  constant,\n  -- ** indexwise unary operators\n  round, sigmoid, relu, floor, square,\n  -- ** Indexwise binary operators\n  addN, (⊕), (⊝), (⊙), (⊘), equal,\n  minT, maxT,\n  -- ** Products\n  (∙), (·), matmul,\n  -- ** Reducers\n  reduceMeanAll, reduceSumAll, reduceMinAll, reduceMaxAll,\n  reduceSum, reduceMean, reduceMin, reduceMax,\n  -- argmax,\n 
 argmax0, argmax1,\n  softmax0, softmax1,\n  -- ** Gradients\n  -- grad,\n  -- clipByGlobalNorm,\n  clipByValue,\n  -- ** Indexing\n  last0, nth0, nth0', lookupT, lookupManyT, gather, range, reverseT,\n  -- ** Split and concatenate\n  slice, slice0, slice1,\n  litStack0,\n  stack0, unstack0,\n  stack1,\n  concatT, concat0, concat1,\n  consT0, snocT0,\n  headT0, tailT0, initT0,\n  -- ** Reshaping\n  expandDim,\n  expandDim0, squeeze0,\n  expandDim1, \n  flatten2, flatten3, flatten12, flattenN2,\n  inflate2, inflate3, inflate12,\n  reshape, flattenAll, inflateAll,\n  -- ** Transposition\n  transposeN, transposeN', transpose01, transposeN01,\n  -- ** Sequences\n  sequenceMask,\n  -- ** Convolutions\n  convolution,\n  -- ** Misc\n  norm, normalize,\n  stopGradient,\n  cast,\n  oneHot0, oneHot1,\n  -- ** complex numbers\n  expm, conjugate, realPart,\n  -- ** Triangular and band Matrices\n  tril, triu, fillTriangular, fillUpperTriangular,\n  -- ** Testing conditions\n  if_, where_, lessThan,\n  -- * Contrib\n  -- ** Mapping\n  mapT, zipWithT, zipWith3T,\n  mapTT, zipWithTT,\n  -- ** Losses\n  sigmoidCrossEntropyWithLogits,\n  softmaxCrossEntropyWithLogits,\n  sparseSoftmaxCrossEntropyWithLogits,\n  -- ** Initializers\n  noise,\n  Distribution(..),\n  varianceScaling, glorotUniform,\n\n  -- ** Heterogeneous vectors\n  repeatT,\n\n  -- ** Heterogeneous heterogeneous vectors\n  repeatHT\n  ) where\n\nimport Prelude hiding (RealFrac(..))\nimport GHC.TypeLits\nimport Data.Proxy\nimport TypedFlow.Types\nimport TypedFlow.Types.Proofs\nimport TypedFlow.Abstract\nimport TypedFlow.Broadcast\n\n-- | Repeat a flexible-shape constant vector to form a heterogeneous tensor vector.\nrepeatT :: forall (ss :: [Shape]) t. All KnownShape ss => KnownLen ss =>\n           (forall s. KnownShape s => T s t) -> HTV t ss\nrepeatT f = zs (typeSList @ss)\n  where zs :: forall (s :: [Shape]). 
All KnownShape s => SList s -> HTV t s\n        zs Unit = Unit\n        zs (_ :* n) = F f :* zs n\n\n-- | Repeat a flexible-shape constant vector to form a heterogeneous tensor vector.\nrepeatHT :: forall ss. All KnownPair ss => KnownLen ss =>\n           (forall s t. KnownShape s => KnownTyp t => T s t) -> HHTV ss\nrepeatHT f = zs (typeSList @ss)\n  where zs :: forall s. All KnownPair s => SList s -> HHTV s\n        zs Unit = Unit\n        zs (_ :* n) = Uncurry f :* zs n\n\n-- | Declare a parameter to optimize.\nparameter' :: ∀ (shape :: Shape) t. (KnownTyp t,KnownShape shape) => String -> T shape t -> Gen (T shape t)\nparameter' = persistent True\n\n-- | Create a parameter.\nparameter :: forall p. KnownTensors p => String -> Gen p -> Gen p\nparameter s p = travTensor parameter' s =<< p\n\n-- | Declare variable which persists between calls to session.run.\npersistent :: ∀ (shape :: Shape) t. (KnownTyp t,KnownShape shape) => Bool -> String -> T shape t -> Gen (T shape t)\npersistent trainable name initial = do\n  T . ExternalVar <$> GPVariable trainable name (Just initial)\n\n\n-- | Modify a mutable tensor. 
Attention: for the assignment to happen,\n-- the resulting tensor must be evaluated!\nmodifyPersistent :: (KnownShape s,KnownTyp t) => T s t -> T s t -> Gen (T s t)\nmodifyPersistent (T (Variable v)) x = GPModify v x -- FIXME: pattern matching here is poor style.\n\n-- type family AddSpatialDims xs ys where\n--   AddSpatialDims '[x] '[] = '[x]\n--   AddSpatialDims (x ': xs) (y ': ys) = (x+(y-1)) ': AddSpatialDims xs ys\n\n-- -- | Convolution operation with no padding (applying the filter only on positions where the input is fully defined)\n-- convolutionValid :: forall outputChannels filterSpatialShape inChannels s t.\n--                KnownLen filterSpatialShape\n--             => Length filterSpatialShape <= 3\n--             => ((1 + Length filterSpatialShape) ~ Length s) -- the last dim of s is the batch size\n--             => T (inChannels ': AddSpatialDims s filterSpatialShape) t -- ^ input tensor (batched)\n--             -> T ('[outputChannels,inChannels] ++ filterSpatialShape) t -- ^ filters\n--             -> T (outputChannels ': s) t\n-- convolutionValid = untypedConvolution \"VALID\"\n\n-- poolNC :: forall dim s inputSpatialShape channels batchSize t.\n--                   (inputSpatialShape ~ Take dim s, '[batchSize] ~ Drop dim s) =>\n--                   T ('[channels] ++ s) t ->\n--                   Vec dim  -> String -> String -> \n--                   T ('[channels] ++ s) t\n-- poolNC (T input) windowShape poolingType padding =\n--    T (funcall \"tf.nn.pool\" [input,list (map float (vecToList windowShape)),text poolingType,text padding,named \"data_format\" (text \"NWC\")])\n\n-- Difficulty: relate windowSize, inputSpatialShape, outputSpatialShape\n\n\n\n\n---------------------------\n-- Contrib\ndata VarianceScaleMode = VSFanIn | VSFanOut | VSAvg\ndata Distrib = NormalDistr | UniformDistr\n\n-- | Random tensor with variance scaling according to deeplearning lore.\nvarianceScaling :: forall inDim outDim t. 
KnownNat inDim => (KnownNat outDim, KnownFloat t) =>\n   Float -> VarianceScaleMode -> Distrib -> Gen (Tensor '[inDim,outDim] t)\nvarianceScaling factor mode distr = noise $ case distr of\n                                   UniformDistr -> UniformD (-limit) limit\n                                   NormalDistr -> TruncatedNormalD limit\n  where\n    fan_in = fromIntegral (natVal (Proxy @inDim))\n    fan_out = fromIntegral (natVal (Proxy @outDim))\n    n = max 1 $ case mode of\n                  VSFanIn -> fan_in\n                  VSFanOut -> fan_out\n                  VSAvg -> (fan_in + fan_out) / 2\n    limit = sqrt ((case distr of NormalDistr -> 1.3; UniformDistr -> 3) * factor / n)\n\n\nglorotUniform :: forall inDim outDim t. KnownNat inDim => (KnownNat outDim, KnownBits t) => Gen (Tensor '[outDim,inDim] ('Typ 'Float t))\nglorotUniform = varianceScaling 1 VSAvg UniformDistr\n\n-- | 'cons' an element and an array (in the first dimension)\nconsT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n =>  T s t -> T (n ': s) t -> T (n+1 ': s) t\nconsT0 x xs = plusComm @1 @n #> concat0 (expandDim0 x) xs\n\n-- | 'snoc' an element and an array (in the first dimension)\nsnocT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n =>  KnownLen s => T (n ': s) t -> T s t -> T (n+1 ': s) t\nsnocT0 xs x = concat0 xs (expandDim0 x)\n\nheadT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n =>  T (n+1 ': s) t -> T (s) t\nheadT0 xs = nth0 0 xs\n\ntailT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n =>  T (n+1 ': s) t -> T (n ': s) t\ntailT0 xs = incrPos @n              #> -- 0 < n+1\n            plusMinusAssoc @n @1 @1 #> -- (n+1) - 1 = -- n+ (1 - 1)\n            slice0 @1 @(n+1) xs\n\ninitT0 :: forall n s t. 
KnownTyp t => KnownShape s => KnownNat n =>  T (n+1 ': s) t -> T (n ': s) t\ninitT0 xs = plusMono @n @1 #> -- n <= n+1\n            slice0 @0 @n xs\n\n----------------\n-- Helpers\n\n-- | Product of a matrix of weights with a vector.\n(∙) :: (KnownNumeric t, KnownNat cols, KnownNat rows, KnownTyp t) => Tensor '[cols, rows] t -> Tensor '[cols] t -> Tensor '[rows] t\nm ∙ v = squeeze0 (matmul (expandDim0 v) m)\ninfixl 7 ∙\n\n-- | Dot product between two vectors.\n(·) :: ∀ n t. (KnownNumeric t, KnownNat n) =>\n  Tensor '[n] t -> Tensor '[n] t -> Tensor '[] t\nx · y = reduceSum0 (x ⊙ y)\ninfixl 7 ·\n\n-- | 2-Norm of a vector\nnorm :: KnownBits t => KnownNat n\n     => T '[n] (Flt t) -> Scalar (Flt t)\nnorm = frobNorm\n\n-- | 2-Norm of a tensor\nfrobNorm :: KnownShape s => KnownBits t => T s (Flt t) -> Scalar (Flt t)\nfrobNorm = sqrt . reduceSumAll . square\n\nnormalize :: (KnownNat n, KnownBits t) =>\n                   T '[n] (Flt t) -> T '[n] (Flt t)\nnormalize v = mapT (/ (norm v + epsilon)) v\n  where epsilon = 1.0e-8\n\nfillTriangular :: forall n l t.\n                  (KnownNat n, KnownNat l, KnownNumeric t, (((l+l)-n) ~ (n*n)), n <= l)\n               => Tensor '[l] t -> Tensor '[n,n] t\nfillTriangular x = plusMinusAssoc @l @l @n #> tril 0 (inflate2 (concat0 x rr))\n  where rr :: Tensor '[l - n] t\n        rr = subIneq @l @n #> slice0 @0 @(l-n) (reverseT x) \n\n\n-- @lookupManyT def indices array@ look up indices in array, returning def if the index is -1\nlookupManyT :: forall s n t. KnownNat n => KnownShape s => (KnownNumeric t) => Scalar t -> T s Int32 -> T '[n] t -> T s t\nlookupManyT def indices array =\n  appRUnit @s #> mapTT @s (\\idx -> where_ (equal idx (-1)) def (lookupT idx array)) indices\n\n\n-- | A flexible upper-triangular matrix function: fill the upper triangle with l elements. \nfillUpperTriangular :: forall n l t. 
KnownNumeric t => KnownNat n => KnownNat l => T '[l] t -> T '[n,n] t\nfillUpperTriangular x =\n  zipWithTT @'[n,n]\n  (\\i j -> let idx :: Scalar Int32\n               idx = ((i * (2 * n - i - 3)) `floorDiv` 2 + j - 1)\n\n-- The index to lookup in the input array. It is computed from the formula:\n-- Output[i,j] = (j-i-1) + ∑_k^(i-1) (n-k)\n--                              \n-- The term j-i-1 is the distance from the upper diagonal.\n-- The sum is the number of elements in the previous rows\n               \n           in where_ (((j - i) `greaterThan` 0) `logicAnd` (idx `lessThan` l))\n                     (lookupT idx x)\n                     zeros)\n    range0 \n    range1 where\n\n  n, l :: Scalar Int32\n  n = constant (fromIntegral (natVal (Proxy @n)))\n  l = constant (fromIntegral (natVal (Proxy @l)))\n  \n  -- \"j\" index\n  range1 :: forall n m w. (KnownNat n, KnownNat m) => KnownBits w => T '[n,m] ('Typ 'Int w)\n  range1 = broadcastT range\n\n  -- \"i\" index\n  range0 :: forall n m w. (KnownNat n, KnownNat m) => KnownBits w => T '[n,m] ('Typ 'Int w)\n  range0 = transpose01 range1\n\n\n-------------------------\n-- Generic parameters\n\n-- | Create a parameter and initialize it with a suitable default for its type. Control the exact initializer using 'parameter'.\nparameterDefault :: forall p. ParamWithDefault p => String -> Gen p\nparameterDefault name = parameter name defaultInitializer\n\n\n-- flattenHTV :: KnownTyp t => All KnownShape xs => HTV t xs -> Tensor '[Sum (Ap (FMap CProduct) xs)] t\n-- flattenHTV Unit = zeros\n-- flattenHTV (F x :* xs) = concat0 (flattenAll x) (flattenHTV xs)\n\n-- class CProduct (xs :: [Nat])\n-- instance Fun CProduct where type Ap CProduct xs = Product xs\n\n-- inflateHTV :: ∀ xs s t. 
(All KnownShape xs, KnownLen s, KnownLen xs) =>\n--           Tensor '[Sum (Ap (FMap CProduct) xs)] t -> Gen (HTV t xs)\n-- inflateHTV (T x) = do\n--   v <- newVar\n--   gen (v <> text \" = \" <> funcall \"tf.split\" [x, showShape' (prodshape @xs shapeSList), text \"axis=0\"])\n--   return (mkArr @xs 0 shapeSList  v)\n--   where mkArr :: forall zs. All KnownShape zs => Int -> SList zs -> DOC -> HTV t zs\n--         mkArr _ LZ _ = Unit\n--         mkArr i (LS _ n) v = F (unsafeReshape (T (v <> brackets (int i)) )):* mkArr (succ i) n v\n--         prodshape :: forall zs. All KnownShape zs => SList zs -> [Integer]\n--         prodshape LZ = []\n--         prodshape (LS xx xs) = product (shapeToList' (shapeSListProxy xx)) : prodshape xs\n\n\n-- -- | Gradient of wrt. given parameters.\n-- grad' :: KnownLen xs => T s Float32 -> HHTV xs -> Gen (HHTV xs)\n-- grad' (T y) vars = do\n--  v <- newVar\n--  v <-- funcall \"tf.gradients\" [y, list (htoList (hmap (\\(Uncurry (T x)) -> K x) vars)) ]\n--  return (mkArr 0 shapeSList v)\n--   where mkArr :: forall xs. Int -> SList xs -> DOC -> HHTV xs\n--         mkArr _ LZ _ = Unit\n--         mkArr i (LS _ n) v = Uncurry (T (v <> brackets (int i))) :* mkArr (succ i) n v\n"
  },
  {
    "path": "TypedFlow/Types/Proofs.hs",
    "content": "{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE GeneralizedNewtypeDeriving #-}\n{-# LANGUAGE InstanceSigs #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE MultiParamTypeClasses #-}\n{-# LANGUAGE OverloadedStrings #-}\n{-# LANGUAGE PatternSynonyms #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE RecordWildCards #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeInType #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UndecidableSuperClasses #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-# LANGUAGE CPP #-}\n#if __GLASGOW_HASKELL__ >= 806\n{-# LANGUAGE NoStarIsType #-}\n#endif\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n\nmodule TypedFlow.Types.Proofs where\n\n\nimport Prelude hiding (RealFrac(..))\nimport GHC.TypeLits\nimport Data.Proxy\nimport TypedFlow.Types hiding (T)\nimport Data.Type.Equality\nimport Unsafe.Coerce\nimport Data.Kind (Type)\nclass SingEq s where\n  testEq :: forall a b. s a -> s b -> Maybe (a :~: b)\n\ninstance SingEq (Sat KnownNat) where\n  testEq :: forall n m. Sat KnownNat n -> Sat KnownNat m -> Maybe (n :~: m)\n  testEq = testNatEqual\n\nnatValS :: forall m. 
Sat KnownNat m -> Integer\nnatValS Sat = natVal (Proxy @m)\n\ntestNatEqual :: Sat KnownNat m -> Sat KnownNat n -> Maybe (m :~: n)\ntestNatEqual m n = if natValS m == natValS n then Just (unsafeCoerce Refl) else Nothing\n\ninstance SingEq f => SingEq (NP f) where\n  testEq Unit Unit = Just Refl\n  testEq (x :* xs) (y :* ys) = case (testEq x y, testEq xs ys) of\n    (Just Refl, Just Refl) -> Just Refl\n    _ -> Nothing\n  testEq _ _ = Nothing\n\ninstance SingEq SKind where\n  testEq SBool SBool = Just Refl\n  testEq SInt SInt = Just Refl\n  testEq SFloat SFloat = Just Refl\n  testEq _ _ = Nothing\n\ninstance SingEq SNBits where\n  testEq SB32 SB32 = Just Refl\n  testEq SB64 SB64 = Just Refl\n  testEq _ _ = Nothing\n\ninstance SingEq STyp where\n  testEq (STyp k b Refl) (STyp k' b' Refl) = case (testEq k k', testEq b b') of\n    (Just Refl, Just Refl) -> Just Refl\n    _ -> Nothing\n\n-- | Use a reified equality relation\n(#>) :: (a :~: b) -> ((a ~ b) => k) -> k\nRefl #> k = k\ninfixr 0 #>\n\n-- | Use a reified arbitrary predicate\n(?>) :: Sat constraint a -> (constraint a => k) -> k\nSat ?> k = k\ninfixr 0 ?>\n\n-- | Use a reified arbitrary constraint\n(??>) :: Dict constraint -> (constraint => k) -> k\nDict ??> k = k\ninfixr 0 ??>\n\nproductS :: forall s. SShape s -> Sat KnownNat (Product s)\nproductS s = knownSShape s ?>\n             knownProduct @s ?>\n             Sat\n\nplusComm :: forall x y. (x + y) :~: (y + x)\nplusComm = unsafeCoerce Refl\n\nplusCommS :: forall x y px py. px x -> py y -> (x + y) :~: (y + x)\nplusCommS _ _ = plusComm @x @y\n\nplusAssoc :: forall x y z. (x + y) + z :~: x + (y + z)\nplusAssoc = unsafeCoerce Refl\n\nplusAssocS :: forall x y z px py pz. px x -> py y -> pz z -> ((x + y) + z) :~: (x + (y + z))\nplusAssocS _ _ _ = plusAssoc @x @y @z\n\nprodAssoc :: forall x y z. (x * y) * z :~: x * (y * z)\nprodAssoc = unsafeCoerce Refl\n\nprodAssocS :: forall x y z px py pz. 
px x -> py y -> pz z -> ((x * y) * z) :~: (x * (y * z))\nprodAssocS _ _ _ = prodAssoc @x @y @z\n\nprodCommS :: forall x y px py. px x -> py y -> (x * y) :~: (y * x)\nprodCommS _ _ = unsafeCoerce Refl\n\ntermCancelation :: forall a b. (a + b) - b :~: a\ntermCancelation = plusMinusAssoc @a @b @b #> cancelation @b #> Refl\n\nplusMinusAssoc :: forall x y z. (x + y) - z :~: x + (y - z)\nplusMinusAssoc = unsafeCoerce Refl\n\ncancelation :: (a - a) :~: 0\ncancelation = unsafeCoerce Refl\n\nplusMono :: forall a b. (a <=? (a+b)) :~: 'True\nplusMono = unsafeCoerce Refl\n\nsuccPos :: (1 <=? 1+j) :~: 'True\n  -- CmpNat 0 (1 + n) :~: 'LT\nsuccPos = unsafeCoerce Refl\n\nsuccPosProx2 :: forall n proxy a. proxy n a -> (0 :<: (1+n))\nsuccPosProx2 _ = succPos @n\n\nprodHomo ::  forall x y. Product (x ++ y) :~: Product x * Product y\nprodHomo = unsafeCoerce Refl\n\nprodHomoS ::  forall x y px py. px x -> py y -> ((Product (x ++ y) :~: (Product x * Product y)))\nprodHomoS _ _ = prodHomo @x @y\n\nknownProduct' :: forall s f. All KnownNat s => NP f s -> Sat KnownNat (Product s)\nknownProduct' Unit = Sat\nknownProduct' (_ :* n) = knownProduct' n ?> Sat\n\nknownProduct :: forall s. KnownShape s => Sat KnownNat (Product s)\nknownProduct = knownProduct' @s typeSList\n\nknownSumS :: forall s. NP (Sat KnownNat) s -> Sat KnownNat (Sum s)\nknownSumS Unit = Sat\nknownSumS (Sat :* n) = knownSumS n ?> Sat\n\nknownSum' :: forall s f. All KnownNat s => NP f s -> Sat KnownNat (Sum s)\nknownSum' proxies = knownSumS (allKnown' proxies)\n\nknownSum :: forall s. KnownShape s => Sat KnownNat (Sum s)\nknownSum = knownSum' @s typeSList\n\nknownPlus :: forall m n. KnownNat m => KnownNat n => Sat KnownNat (m + n)\nknownPlus = Sat\n\ntakeDrop :: forall s n. (PeanoNat n <= Length s) => (Take n s ++ Drop n s) :~: s\ntakeDrop = unsafeCoerce Refl\n\nlengthHomo :: forall x y. Length (x ++ y) :~: Length x + Length y\nlengthHomo = unsafeCoerce Refl\n\nlengthHomoS :: forall x y proxyx proxyy. 
proxyx x -> proxyy y -> ((Length (x ++ y) :~: (Length x + Length y)))\nlengthHomoS _ _ = lengthHomo @x @y\n\nlengthInit :: forall s. (0 < Length s) => SList s -> ((Length (Init s) + 1) :~: Length s)\nlengthInit x = lengthHomo @(Init s) @'[Last s] #> initLast x #> Refl\n\ntype a :<=: b = ((a <=? b):~: 'True)\ntype i :<: j = (i+1) :<=: j\n\nincrPos :: forall x. 1 :<=: (x+1)\nincrPos = unsafeCoerce Refl\n\n\nsubIneq :: forall x k. (x - k) :<=: x\nsubIneq = unsafeCoerce Refl\n\nincrCong :: forall x y. ((x+1) ~ (y+1)) => x :~: y\nincrCong = unsafeCoerce Refl\n\ninitLast :: forall s. {-(0 < Length s) => FIXME -} SList s -> ((Init s ++ '[Last s]) :~: s)\ninitLast Unit = error \"initLast': does not hold on empty lists\"\ninitLast ((:*) _ Unit) = Refl\ninitLast ((:*) _ ((:*) y ys)) = initLast ((:*) y ys) #> Refl\n\ninitLast' :: forall s. {-(0 < Length s) => FIXME -} KnownShape s => ((Init s ++ '[Last s]) :~: s)\ninitLast' = initLast (typeSList @s)\n\nappRUnit :: forall s. (s ++ '[]) :~: s\nappRUnit = unsafeCoerce Refl\n\nappAssoc ::  ((xs ++ ys) ++ zs) :~: (xs ++ (ys ++ zs))\nappAssoc = unsafeCoerce Refl\n\nappAssocS :: forall xs ys zs proxy1 proxy2 proxy3.\n             proxy1 xs -> proxy2 ys -> proxy3 zs -> (((xs ++ ys) ++ zs) :~: (xs ++ (ys ++ zs)))\nappAssocS _ _ _  = appAssoc @xs @ys @zs\n\n\nknownLast' :: All KnownNat s => SList s -> (KnownNat (Last s) => k) -> k\nknownLast' Unit _ = error \"knownLast: does not hold on empty lists\"\nknownLast' ((:*) _ Unit) k = k\nknownLast' ((:*) _ ((:*) y xs)) k = knownLast' ((:*) y xs) k\n\nknownLast :: forall s k. KnownShape s => (KnownNat (Last s) => k) -> k\nknownLast = knownLast' @s typeSList\n\nknownInit' :: All KnownNat s => SList s -> Sat KnownShape (Init s)\nknownInit' Unit = error \"knownLast: does not hold on empty lists\"\nknownInit' ((:*) _ Unit) = Sat\nknownInit' ((:*) _ ((:*) y xs)) = knownInit' ((:*) y xs) ?> Sat\n\nknownInit :: forall s. 
KnownShape s => Sat KnownShape (Init s)\nknownInit = knownInit' @s typeSList\n\nknownTail' :: forall x s k. All KnownNat s => SList (x ': s) -> (KnownShape s => k) -> k\nknownTail' ((:*) _ Unit) k = k\nknownTail' ((:*) _ ((:*) y xs)) k = knownTail' ((:*) y xs) k\n\nknownTail :: forall s x xs k. (s ~ (x ': xs), KnownShape s) => (KnownShape xs => k) -> k\nknownTail = knownTail' @x @xs typeSList\n\nknownAppendS :: forall s t pt. (All KnownNat s, KnownShape t) => SList s -> pt t -> Sat KnownShape (s ++ t)\nknownAppendS Unit _t = Sat\nknownAppendS ((:*) _ n) t = knownAppendS n t ?> Sat\n\nknownAppend :: forall s t.  (KnownShape s, KnownShape t) => Sat KnownShape (s ++ t)\nknownAppend = knownAppendS (typeSList @s) (Proxy @t)\n\n\n-- knownFmap' :: forall f xs. SList xs -> SList (Ap (FMap f) xs)\n-- knownFmap' Unit = Unit\n-- knownFmap' ((:*) x n) = (:*) Proxy (knownFmap' @f n)\n\nknownSList :: NP proxy xs -> Sat KnownLen xs\nknownSList Unit = Sat\nknownSList (_ :* n) = knownSList n ?> Sat\n\nknownSShape :: SShape xs -> Sat KnownShape xs\nknownSShape Unit = Sat\nknownSShape ((:*) Sat s) = knownSShape s ?> Sat\n\ndata DimExpr (a :: Nat) (x :: Nat) (b :: Nat) where\n  ANat :: Sat KnownNat x -> DimExpr a x (a * x)\n  (:*:) :: DimExpr a x b -> DimExpr b y c -> DimExpr a (x*y) c\n\nknownOutputDim :: forall a x b. Sat KnownNat a -> DimExpr a x b -> Sat KnownNat b\nknownOutputDim a (ANat x) = satMul a x\nknownOutputDim a (x :*: y) = knownOutputDim (knownOutputDim a x) y\n\ndimSat :: DimExpr a x b -> Sat KnownNat x\ndimSat (ANat s) = s\ndimSat (x :*: y) = dimSat x `satMul` dimSat y\n\nnormDim :: forall ws xs ys. 
DimExpr ws xs ys -> (ws * xs) :~: ys\nnormDim (ANat _) = Refl\nnormDim (a :*:b) = normDim a #>\n                   normDim b #>\n                   prodAssocS (Proxy @ws) (dimSat a) (dimSat b) #>\n                   Refl\n\ndata ShapeExpr (a :: Nat) (x :: Shape) (b :: Nat) where\n  Single :: DimExpr a x b -> ShapeExpr a '[x] b\n  AShape :: SShape x -> ShapeExpr a x (a * Product x)\n  (:++:) :: ShapeExpr a x b -> ShapeExpr b y c -> ShapeExpr a (x++y) c\n\ninfixr 5 :++:\ninfixr 5 *:!\ninfixr 5 !:*\n\n(!:*) :: DimExpr a x b -> ShapeExpr b xs c -> ShapeExpr a (x ': xs) c\nx !:* xs = Single x :++: xs\n\n(*:!) :: ShapeExpr a xs b -> DimExpr b x c -> ShapeExpr a (xs ++ '[x]) c\nxs *:! x = xs :++: Single x\n\nexprSShape :: forall a x b. ShapeExpr a x b -> SShape x\nexprSShape (AShape s) = s\nexprSShape (Single x) = dimSat x ?> typeSShape\nexprSShape (x :++: y) = exprSShape x .+. exprSShape y\n\nnormShape :: forall ws xs ys. ShapeExpr ws xs ys -> (ws * Product xs) :~: ys\nnormShape (Single x) = normDim x\nnormShape (AShape _) = Refl\nnormShape (l :++: r) = normShape l #>\n                       normShape r #>\n                       prodHomoS (exprSShape l) (exprSShape r) #>\n                       prodAssocS (Proxy @ws) (productS (exprSShape l)) (productS (exprSShape r)) #>\n                       Refl\n        -- r :: normShape b y ys ----> (b * y) ~ ys   (1)\n        -- l :: normShape ws x b ----> (ws * x) ~ b   (2)\n        -- subst (2) in (1): ((ws * x) * y) ~ ys\n        -- assoc: (ws * (x * y)) ~ ys\n\ndecideProductEq1 :: forall xs zs. ShapeExpr 1 xs zs -> Product xs :~: zs\ndecideProductEq1 a  = case normShape a of Refl -> Refl\n\ntype ShapeX = ShapeExpr 1\n\ndecideProductEq :: ShapeExpr 1 xs zs -> ShapeExpr 1 ys zs -> Product xs :~: Product ys\ndecideProductEq l r = case decideProductEq1 l of\n                        Refl -> case decideProductEq1 r of\n                          Refl -> Refl\n\n\nunsafePositive :: (1 <=? 
n) :~: 'True\nunsafePositive = unsafeCoerce Refl\n\nsucPred :: ((1 <=? n) ~ 'True) => (n - 1) + 1  :~: n\nsucPred = unsafeCoerce Refl\n\n-- data ORDEQ p a b where\n--   LT, GT :: ORDEQ a b\n--   EQ :: p -> ORDEQ a a\n\ndata NatExpr n where\n  NEVar :: Int -> NatExpr n\n  (::+) ::  NatExpr m -> NatExpr n -> NatExpr (m+n)\n  (::*) ::  NatExpr m -> NatExpr n -> NatExpr (m*n)\n\ndata NatSum n where\n  NSZero :: NatSum 0\n  NSAdd :: NatProd m -> NatSum n -> NatSum (m+n)\n\ndata NatProd n where\n  NPUnit :: Sat KnownNat k -> NatProd k\n  NPTimes :: Sat KnownNat m -> Int -> NatProd n -> NatProd (m*n)\n\nsortProd :: NatProd n -> NatProd n\nsortProd (NPUnit k) = NPUnit k\nsortProd (NPTimes x xId y) = insertProd x xId (sortProd y)\n  where insertProd :: Sat KnownNat m -> Int -> NatProd n -> NatProd (m*n)\n        insertProd x xId rest = case rest of\n          (NPUnit k) -> NPTimes x xId rest\n          (NPTimes y yId ys) -> if xId <= yId\n                                then NPTimes x xId rest\n                                else prodAssocS x y ys #>\n                                     prodCommS x y #>\n                                     prodAssocS y x ys #>\n                                     NPTimes y yId (insertProd x xId ys)\n\nsortSum :: NatSum n -> NatSum n\nsortSum NSZero = NSZero\nsortSum (NSAdd x y) = insertTerm (sortProd x) (sortSum y)\n  where insertTerm :: NatProd m -> NatSum n -> NatSum (m+n)\n        insertTerm p rest = case rest of\n          NSZero -> NSAdd p NSZero\n          NSAdd q qs -> case compareTerms p q of\n            Right p' -> plusAssocS p q qs #> NSAdd p' qs\n            Left False -> NSAdd p rest\n            Left True -> plusAssocS p q qs #>\n                         plusCommS  p q #>\n                         plusAssocS q p qs #>\n                         NSAdd q (insertTerm p qs)\n\n\ncompareTerms :: NatProd n -> NatProd m -> Either Bool (NatProd (n+m))\ncompareTerms (NPUnit Sat) (NPUnit Sat) = Right (NPUnit Sat)\ncompareTerms 
(NPUnit _) (NPTimes _ _ _) = Left False\ncompareTerms (NPTimes _ _ _) (NPUnit _)  = Left True\ncompareTerms (NPTimes x xId xs) (NPTimes y yId ys) =\n  case testEq x y of\n    Nothing -> Left (natValS x <= natValS y)\n    Just Refl -> case compareTerms xs ys of\n      Left x -> Left x\n      Right p -> distrLS x xs ys #> Right (NPTimes x xId p) \n\n  \n\ndistrLS :: forall a b c px py pz. px a -> py b -> pz c ->  a * (b + c) :~: (a * b + a * c)\ndistrLS = unsafeCoerce Refl\n\ndistrRS :: forall a b c px py pz. px a -> py b -> pz c ->  (a + b) * c :~: ((a * c) + (b * c))\ndistrRS = unsafeCoerce Refl\n\nexpandProd :: NatSum m -> NatSum n -> NatSum (m*n)\nexpandProd NSZero _ = NSZero\nexpandProd _ NSZero = NSZero\nexpandProd (NSAdd a b) c = distrRS a b c #> expandSum (expandP' a c) (expandProd b c)\n\nexpandP' :: NatProd m -> NatSum n -> NatSum (m*n)\nexpandP' _ NSZero = NSZero\nexpandP' p (NSAdd q a) = distrLS p q a #> NSAdd (expandPP p q) (expandP' p a)\n\nexpandPP :: NatProd m -> NatProd n -> NatProd (m*n)\nexpandPP (NPUnit k) (x) = expandKP k x\nexpandPP (NPTimes x xId y) z = prodAssocS x y z #> NPTimes x xId (expandPP y z)\n\nexpandKP :: Sat KnownNat k -> NatProd n -> NatProd (k*n)\nexpandKP Sat (NPUnit Sat) = NPUnit Sat\nexpandKP k@Sat (NPTimes x xId y)\n  = prodAssocS k x y #>\n    prodCommS k x #>\n    prodAssocS x k y #>\n    NPTimes x xId (expandKP k y) \n\nexpandSum :: NatSum m -> NatSum n -> NatSum (m+n)\nexpandSum NSZero x = x\nexpandSum (NSAdd x y) z = plusAssocS x y z #> NSAdd x (expandSum y z)\n\n\nnatRec :: forall (n :: Nat) (p :: Nat -> Type). KnownNat n => p 0 -> (forall (m :: Nat). p m -> p (m+1)) -> p n\nnatRec z s = case natVal (Proxy @n) of\n  0 -> unsafeCoerce z\n  _ -> case unsafePositive @n of\n         Refl -> case sucPred @n of\n           Refl -> s @(n-1) (natRec @(n-1) @p z s)\n\n\ndata CountRes n where\n  CountRes :: Integer -> V n Integer -> CountRes n\n\nvcount :: forall n. 
KnownNat n => V n Integer\nvcount =\n  case natRec @n (CountRes (natVal (Proxy @n)-1) VUnit) (\\(CountRes m xs) ->\n                                                            plusCommS (Proxy @1) (F xs)\n                                                            #> CountRes (m-1) (m :** xs)) of\n  CountRes _ x -> x\n\ndata V n a where\n  VUnit :: V 0 a\n  (:**) :: a -> V n a -> V (1+n) a\ninfixr 5 :**\n\nderiving instance (Functor (V n))\n\ninstance KnownNat n => Applicative (V n) where\n  pure x = fmap (const x) (vcount @n)\n  VUnit <*> VUnit = VUnit\n  (f :** fs) <*> (a :** as) = succPosProx2 fs #> (f a :** (fs <*> unsafeCoerce as))\n"
  },
  {
    "path": "TypedFlow/Types.hs",
    "content": "{-# LANGUAGE QuantifiedConstraints #-}\n{-# LANGUAGE CPP #-}\n#if __GLASGOW_HASKELL__ >= 806\n{-# LANGUAGE NoStarIsType #-}\n#endif\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# LANGUAGE ConstraintKinds #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE DeriveFoldable #-}\n{-# LANGUAGE DeriveFunctor #-}\n{-# LANGUAGE DeriveTraversable #-}\n{-# LANGUAGE FlexibleContexts #-}\n{-# LANGUAGE FlexibleInstances #-}\n{-# LANGUAGE GADTs #-}\n{-# LANGUAGE GeneralizedNewtypeDeriving #-}\n{-# LANGUAGE InstanceSigs #-}\n{-# LANGUAGE MagicHash #-}\n{-# LANGUAGE MultiParamTypeClasses #-}\n{-# LANGUAGE OverloadedStrings #-}\n{-# LANGUAGE PatternSynonyms #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE RecordWildCards #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE StandaloneDeriving #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeInType #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UndecidableInstances #-}\n{-# LANGUAGE UndecidableSuperClasses #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-# LANGUAGE LambdaCase #-}\n{-# LANGUAGE ApplicativeDo #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n\nmodule TypedFlow.Types where\n\nimport GHC.TypeLits \nimport Data.Proxy\n-- import Control.Monad.State\n-- import Control.Monad.RWS (RWS(..), local, ask, tell)\nimport Data.Kind (Constraint,Type)\nimport qualified Data.Int as Hask\nimport Data.Type.Equality\nimport Data.Monoid hiding (Sum,Product,Last,All,Ap)\nimport Data.Complex\n\nnewtype (∘) f (g :: k -> k2) (a::k) where\n  Comp :: forall f g a. f (g a) -> (f ∘ g) a\n\ntype Sat = (∘) Dict\ntype Sat' f x = f x\n\ndata Dict :: Constraint -> Type where\n  Dict :: a => Dict a\n\npattern Sat :: forall k (g :: k -> Constraint) (a :: k). () => g a => (∘) Dict g a -- the second context is the PROVIDED constraint!\npattern Sat = Comp Dict\n\ninstance (Show (Sat a b)) where\n  show _ = \"Sat\"\n\nproxySat :: forall k (b::k) (a :: k -> Constraint) proxy. 
a b => proxy b -> Sat a b\nproxySat _ = Sat\n\nnatSat :: forall n. KnownNat n => Sat KnownNat n\nnatSat = Sat\n\n-- type i < j = CmpNat i j ~ 'LT\ntype i < j = (i+1) <= j\n-- type i <= j = (i <=? j) ~ 'True\n\ntype family Product xs where\n  Product '[] = 1\n  Product (x ': xs) = x * Product xs\n\ntype family Sum xs where\n  Sum '[] = 0\n  Sum (x ': xs) = x + Sum xs\n\n\ntype family (++) xs ys where\n   '[] ++  xs       = xs\n   (x ': xs) ++ ys       = x ': (xs ++ ys)\n\ntype family Tail xs where\n  Tail (x ': xs) = xs\n\ntype family Last xs where\n  Last '[x] = x\n  Last (x ': xs) = Last xs\n\ntype family Init xs where\n  Init '[x] = '[]\n  Init (x ': xs) = x ': Init xs\n\n\n\ntype family Length xs where\n  Length '[] = 0\n  Length (x ': xs) = 1 + Length xs\n\ntype family Reverse' xs ys where\n  Reverse' '[] ys = ys\n  Reverse' (x ': xs) ys = Reverse' xs (x ': ys )\n\ntype family Reverse xs where\n  Reverse xs = Reverse' xs '[]\n\n-- From: https://www.cs.ox.ac.uk/projects/utgp/school/andres.pdf\ndata NP f (xs :: [k]) where\n  Unit :: NP f '[]\n  (:*) :: f x -> NP f xs -> NP f (x ': xs)\n\nderiving instance (forall x. Show (f x)) => Show (NP f xs)\ntype SList' = NP\n\n(.+.) = appSList\ninfixr 5 .+.\ninfixr 5 *:\ninfixr 5 :*\n\n(*:) :: forall x xs f. 
NP f xs -> f x -> NP f (xs ++ '[x])\nxs *: x = appSList xs (x :* Unit)\n\nhlookup :: Axis n xs -> NP f xs -> f (At n xs)\nhlookup AxZero  (x :* _) = x\nhlookup (AxSucc n) (_ :* xs) = hlookup n xs\n\nnewtype I a = I a\nnewtype K a x = K a\ntype HList = NP I\n\npattern HSingle :: f a -> NP f '[a]\npattern HSingle x = x :* Unit\n\npattern VecSing :: Tensor s t -> HTV t '[s]\npattern VecSing t1 = F t1 :* Unit\n\npattern VecPair :: Tensor s t -> Tensor s' t -> HTV t '[s,s']\npattern VecPair t1 t2 = F t1 :* F t2 :* Unit\n\npattern VecTriple :: Tensor s t -> Tensor s' t -> Tensor s3 t -> HTV t '[s,s',s3]\npattern VecTriple t1 t2 t3 = F t1 :* F t2 :* F t3 :* Unit\n\ntype family All (c :: k -> Constraint) (xs :: [k]) :: Constraint where\n  All c '[] = ()\n  All c (x ': xs) = (c x, All c xs)\n\nknownAll :: forall constraint s k. NP (Sat constraint) s -> (All constraint s => KnownLen s => k) -> k\nknownAll Unit k = k\nknownAll (Sat :* xs) k = knownAll xs $ k\n\nallKnown' :: forall constraint s proxy. All constraint s => NP proxy s -> NP (Sat constraint) s\nallKnown' Unit = Unit\nallKnown' (_ :* xs) = Sat :* allKnown' xs\n\nallKnown :: forall k s. KnownLen s => All k s => NP (Sat k) s\nallKnown = allKnown' typeSList\n\ndata SomeSuch k f where\n  SomeSuch :: k x => f x -> SomeSuch k f\n\nclass Fun (c :: k -> Constraint)  where -- FIXME: use type, not constraint?\n  type Ap c (t :: k) :: l\n\nclass Cons (x :: k) (xs :: [k])\ninstance Fun (Cons x) where type Ap (Cons x) xs = x ': xs\n\nclass Snoc (x :: k) (xs :: [k])\ninstance Fun (Snoc x) where\n  type Ap (Snoc x) '[] = '[x]\n  type Ap (Snoc x) (y ': ys) = y ': Ap (Snoc x) ys\n\nclass FMap (c :: k -> Constraint) (xs :: [k]) where\n\ninstance Fun c => Fun (FMap c)  where\n  type Ap (FMap c) '[] = '[]\n  type Ap (FMap c) (x ': xs) = Ap c x ': Ap (FMap c) xs\n\nmapFMap :: forall g f xs. (forall x. 
f x -> f (Ap g x)) -> NP f xs -> NP f (Ap (FMap g) xs)\nmapFMap _ Unit = Unit\nmapFMap f (x :* xs) = f x :* mapFMap @g @f f xs\n\n-- type family All2 (c :: k -> l -> Constraint) (xs :: [k]) (ys :: [l]) :: Constraint where\n--   All2 c '[] '[] = ()\n--   All2 c (x ': xs) (y ': ys) = (c x y, All2 c xs ys)\n--   All2 c '[] (y ': ys) = 'True ~ 'False\n--   All2 c (y ': ys) '[] = 'True ~ 'False\n\n-- | Flip at type level\nnewtype F g t s = F {fromF :: g s t}\n\n\n-- | Tensor vector. (Elements in the indexing list are ignored.)\ntype TV s t = NP (K (Tensor s t))\n\n-- | Heterogeneous tensor vector with varying shapes and the same kind of elements\ntype HTV t = NP (F T t)\n\nclass Scnd' (x::(a,b))\ninstance Fun (Scnd') where type Ap Scnd' '(a,b) = b\n\nclass Frst' (x::(a,b))\ninstance Fun (Frst') where type Ap Frst' '(a,b) = a\n\n\ntype family Frst (x :: (a,b)) where Frst '(x,y) = x\ntype family Scnd (x :: (a,b)) where Scnd '(x,y) = y\n\ntype family Frst3 (x :: (a,b,c)) where Frst3 '(x,y,z) = x\ntype family Scnd3 (x :: (a,b,c)) where Scnd3 '(x,y,z) = y\ntype family Thrd3 (x :: (a,b,c)) where Thrd3 '(x,y,z) = z\n\nclass (KnownShape (Scnd3 r), KnownTyp (Thrd3 r), KnownSymbol (Frst3 r)) => KnownPlaceholder r where\n  placeHolderRef :: proxy r -> Ref String (Scnd3 r) (Thrd3 r)\n\ninstance (KnownShape y, KnownTyp z, KnownSymbol x) => KnownPlaceholder '(x,y,z) where\n  placeHolderRef _ = Ref (symbolVal (Proxy @x)) typeSShape typeSTyp\nclass (KnownShape (Frst r), KnownTyp (Scnd r)) => KnownPair r\ninstance (KnownShape x, KnownTyp y) => KnownPair '(x,y)\n\n\nnewtype Uncurry g (s :: (a,b)) = Uncurry {fromUncurry :: g (Frst s) (Scnd s)}\n\n-- | Tensor vector heterogenous in types and shapes.\ntype HHTV = NP (Uncurry T)\n\ntype Placeholders = NP Placeholder\ntype PH = (Symbol,Shape,Typ)\nnewtype Placeholder (s :: PH) = PHT (T (Scnd3 s) (Thrd3 s))\n\nhhead :: NP f (x ': xs) -> f x\nhhead (x :* _) = x\n\nhtail :: NP f (x ': xs) -> NP f xs\nhtail (_ :* xs) = xs\n\nhtmap :: forall f ss 
t u. (forall s. Tensor s t -> Tensor (Ap f s) u) -> HTV t ss -> HTV u (Ap (FMap f) ss)\nhtmap _ Unit = Unit\nhtmap f (F x :* xs) = F (f x) :* htmap @f f xs\n\n-- htmap' :: forall f ss t u. All KnownShape ss => (forall s. KnownShape s => Tensor (Ap f s) t -> Tensor s u) -> SList ss -> HTV t (Ap (FMap f) ss) -> HTV u ss \n-- htmap' _ Unit Unit = Unit\n-- htmap' f ((:*) _ n)(F x :* xs) = F (f x) :* htmap' @f f n xs\n\n-- | Map a natural transformation\nhmap :: (forall x. f x -> g x) -> NP f xs -> NP g xs\nhmap _ Unit = Unit\nhmap f (x :* xs) = f x :* hmap f xs\n\nhTraverse :: forall f g xs m. Applicative m => (forall x. f x -> m (g x)) -> NP f xs -> m (NP g xs)\nhTraverse _ Unit = pure Unit\nhTraverse f (x :* xs) = do\n  x' <- f x\n  xs' <- hTraverse f xs\n  return (x' :* xs')\n\n-- | Variant of hmap with a constraint\nhmapK :: forall k f g xs. All k xs => (forall x. k x => f x -> g x) -> NP f xs -> NP g xs\nhmapK _ Unit = Unit\nhmapK f (x :* xs) = f x :* hmapK @k f xs\n\nhMapToList :: forall k f xs a. All k xs => (forall x. k x => f x -> a) -> NP f xs -> [a]\nhMapToList f = htoList . hmapK @k (K . f)\n\n-- | If NP is in fact a vector, we have a \"usual\" map.\nkmap :: (a -> b) -> NP (K a) xs -> NP (K b) xs\nkmap _ Unit = Unit\nkmap f (K x :* xs) = K (f x) :* kmap f xs\n\n-- | If NP is in fact a tuple, we can apply a tuple of endomorphisms. (special case of <*>)\nhendo :: NP Endo xs -> HList xs -> HList xs\nhendo Unit Unit = Unit\nhendo (Endo f :* fs) (I x :* xs) = (I (f x) :* hendo fs xs)\n\nappSList, (.+.), happ :: NP f xs -> NP f ys -> NP f (xs ++ ys)\nhapp Unit xs = xs\nhapp (x :* xs) ys = x :* (happ xs ys)\nappSList = happ\n\ndata Both f g x = Both {frst :: f x, scnd :: g x}\n\nbothFromPair :: (f x, g x) -> Both f g x\nbothFromPair (x,y) = (Both x y)\n\nbothToPair :: Both f g x -> (f x, g x)\nbothToPair (Both x y)  = (x,y)\n\n\nhzip :: NP f xs -> NP g xs -> NP (Both f g) xs\nhzip = hzipWith Both\n\nhzipWith :: (forall x. 
f x -> g x -> h x) -> NP f xs -> NP g xs -> NP h xs\nhzipWith _ Unit Unit = Unit\nhzipWith f (x :* xs) (y :* ys) = f x y :* hzipWith f xs ys\n\nhfor :: forall k f xs m a. All k xs => Applicative m => NP f xs -> (forall x. k x => f x -> m a) -> m [a]\nhfor Unit _  = pure []\nhfor (x :* xs) f = (:) <$> f x <*> hfor @k xs f\n\nhtoList :: NP (K a) xs -> [a]\nhtoList Unit = []\nhtoList (K x :* xs) = x : htoList xs\n\n\nhsplit' :: SPeano n -> NP f xs -> (NP f (Take n xs), NP f (Drop n xs))\nhsplit' SZero xs = (Unit,xs)\nhsplit' (SSucc _n) Unit = (Unit,Unit)\nhsplit' (SSucc n) (x :* xs) = case hsplit' n xs of\n  (l,r) -> (x :* l,r)\n\nhsplit :: forall xs ys f. KnownLen xs => NP f (xs++ys) -> (NP f xs, NP f ys)\nhsplit xys = splitApp @xs @ys (hsplit' (shapePeano @xs) xys)\n\nsplitApp' :: forall ys xs k. SList xs -> ((Take (PeanoLength xs) (xs ++ ys) ~ xs,\n                                              Drop (PeanoLength xs) (xs ++ ys) ~ ys) => k) -> k\nsplitApp' Unit k = k\nsplitApp' ((:*) _ n) k = splitApp' @ys n k\n\nsplitApp :: forall xs ys k. 
KnownLen xs => ((Take (PeanoLength xs) (xs ++ ys) ~ xs,\n                                             Drop (PeanoLength xs) (xs ++ ys) ~ ys) => k) -> k\nsplitApp = splitApp' @ys (typeSList @xs)\n\nhsnoc :: NP f xs -> f x -> NP f (xs ++ '[x])\nhsnoc xs x = happ xs (x :* Unit)\n\n\ndata Peano = Zero | Succ Peano -- TODO: type Peano = '[()] (And then SPeano = NP) ?\n\naxis0 :: Axis 'Zero (x ': xs)\naxis0 = AxZero\naxis1 :: Axis ('Succ 'Zero) (x0 ': (x1 ': xs))\naxis1 = AxSucc axis0\naxis2 :: Axis ('Succ ('Succ 'Zero)) (x0 ': (x1 ': (x2 ': xs)))\naxis2 = AxSucc axis1\naxis3 :: Axis ('Succ ('Succ ('Succ 'Zero))) (x0 ': (x1 ': (x2 ': (x3 ': xs))))\naxis3 = AxSucc axis2\n\n\ndata Axis n xs where\n  AxZero :: Axis 'Zero (x ': xs)\n  AxSucc :: Axis n xs -> Axis ('Succ n) (x ': xs)\n\naxisInt :: Axis n xs -> Integer\naxisInt AxZero = 0\naxisInt (AxSucc n) = 1 + axisInt n\n\nsPeanoInt :: SPeano n -> Integer\nsPeanoInt (SSucc n) = 1 + sPeanoInt n\nsPeanoInt SZero = 0\n\ntype family PeanoNat (n::Peano) :: Nat where\n  PeanoNat 'Zero = 0\n  PeanoNat ('Succ n) = PeanoNat n + 1\n\ndata SPeano n where\n  SZero :: SPeano 'Zero\n  SSucc :: SPeano n -> SPeano ('Succ n)\n\nclass KnownPeano n where\n  knownPeano :: SPeano n\n\ninstance KnownPeano 'Zero where\n  knownPeano = SZero\n\ninstance KnownPeano n => KnownPeano ('Succ n) where\n  knownPeano = SSucc knownPeano\n\ntype family Take n xs where\n   Take 'Zero xs            =  '[]\n   Take ('Succ n) '[] =  '[]\n   Take ('Succ n) (x ': xs) =  x ': Take n xs\n\ntype family Drop n xs where\n   Drop 'Zero xs            = xs\n   Drop _ '[]       = '[]\n   Drop ('Succ n) (x ': xs) = Drop n xs\n\ntype family At n xs where\n  At 'Zero (x ': xs) = x\n  At ('Succ n) (x ': xs) = At n xs\n\n-- type family Drop n xs where\n--    Drop 'Zero xs            = xs\n--    Drop _ '[]       = '[]\n--    Drop ('Succ n) (x ': xs) = Drop n xs\n\n-- type family At n xs where\n--   At 'Zero (x ': xs) = x\n--   At ('Succ n) (x ': xs) = At n xs\n\ndata Kind = 
Float | Cmplx | Int | Bool deriving (Show,Eq,Ord)\ndata SKind (s::Kind) where\n  SFloat :: SKind 'Float\n  SCmplx :: SKind 'Cmplx\n  SInt :: SKind 'Int\n  SBool :: SKind 'Bool\n\ndata NBits = B32 | B64 | B1 deriving (Show,Eq,Ord)\n\ndata SNBits s where\n  SB32 :: SNBits 'B32\n  SB64 :: SNBits 'B64\n\ndata Typ = Typ Kind NBits deriving (Eq,Ord)\ntype family TypKind (t :: Typ) where TypKind ('Typ k b)  = k\ntype family TypBits (t :: Typ) where TypBits ('Typ k b)  = b\n\ntype KnownNumeric t = (NumericKind (TypKind t), KnownBits (TypBits t), t ~ 'Typ (TypKind t) (TypBits t))\ntype KnownFloat t = (TypKind t ~ 'Float, KnownBits (TypBits t), t ~ 'Typ 'Float (TypBits t))\ntype KnownAlgebraic t = (AlgebraicKind (TypKind t), KnownBits (TypBits t), t ~ 'Typ (TypKind t) (TypBits t))\n\n\nclass KnownKind t => NumericKind t where\ninstance NumericKind 'Float\ninstance NumericKind 'Cmplx\ninstance NumericKind 'Int\nclass NumericKind t => AlgebraicKind t where\ninstance AlgebraicKind 'Float\ninstance AlgebraicKind 'Cmplx\n\nkVal :: SKind t1 -> Kind\nkVal SFloat = Float\nkVal SInt = Int\nkVal SBool = Bool\nkVal SCmplx = Cmplx\n\ninstance Eq (SKind t) where x == y = kVal x == kVal y\ninstance Ord (SKind t) where compare x y = compare (kVal x) (kVal y)\n\nnbitsVal :: SNBits w -> NBits\nnbitsVal SB64 = B64\nnbitsVal SB32 = B32\n\ninstance Eq (SNBits t) where x == y = nbitsVal x == nbitsVal y\ninstance Ord (SNBits t) where compare x y = compare (nbitsVal x) (nbitsVal y)\n\nsTypTyp :: STyp t1 -> Typ\nsTypTyp (STyp k b Refl) = Typ (kVal k) (nbitsVal b)\n\ninstance Eq (STyp t) where x == y = sTypTyp x == sTypTyp y\ninstance Ord (STyp t) where compare x y = compare (sTypTyp x) (sTypTyp y)\n\ndata STyp t where\n  STyp :: SKind (TypKind t) -> SNBits (TypBits t) -> (t :~: 'Typ (TypKind t) (TypBits t)) -> STyp t\n\ntype Flt t = 'Typ 'Float t\ntype Float32 = 'Typ 'Float 'B32\ntype Complex32 = 'Typ 'Cmplx 'B32\ntype Int32 = 'Typ 'Int 'B32\ntype Int64 = 'Typ 'Int 'B64\ntype TFBool = 'Typ 'Bool 
'B32\ntype Scalar t = T '[] t\n\ntype Shape = [Nat]\n\nclass (KnownLen s, All KnownNat s) => KnownShape s where\n\ninstance KnownShape '[]\ninstance (KnownNat x, KnownShape xs) => KnownShape (x ': xs)\n\ntype KnownTyp t = (KnownBits (TypBits t), KnownKind (TypKind t), t ~ 'Typ (TypKind t) (TypBits t))\n\ntypeSTyp :: forall t. KnownTyp t => STyp t\ntypeSTyp = STyp (kindVal @(TypKind t)) (bitsVal @(TypBits t)) Refl\n\ntype family HaskType t where\n  HaskType Float32 = Float\n  HaskType ('Typ 'Float 'B64) = Double\n  HaskType ('Typ 'Cmplx 'B32) = Complex Float\n  HaskType ('Typ 'Cmplx 'B64) = Complex Double\n  HaskType ('Typ 'Int 'B64) = Hask.Int64\n  HaskType ('Typ 'Int 'B32) = Hask.Int32\n  HaskType ('Typ 'Bool w) = Bool\n\nclass KnownBits t where\n  bitsVal :: SNBits t\n\ninstance KnownBits 'B32 where bitsVal = SB32\ninstance KnownBits 'B64 where bitsVal = SB64\n\ntypVal :: forall t. KnownTyp t => Typ\ntypVal = Typ (kVal k) (nbitsVal b)\n  where k = kindVal @(TypKind t)\n        b = bitsVal @(TypBits t)\n\nknownBits :: SNBits t -> (KnownBits t => Fractional (HaskType ('Typ 'Float t)) => Floating (HaskType ('Typ 'Float t)) => k) -> k\nknownBits SB32 k = k\nknownBits SB64 k = k\n\nknownKind :: SKind t -> (KnownKind t => k) -> k\nknownKind SFloat k = k\nknownKind SInt k = k\nknownKind SBool k = k\nknownKind SCmplx k = k\n\nknownTyp :: STyp t -> (KnownTyp t => k) -> k\nknownTyp (STyp k b Refl) r = knownKind k $ knownBits b r\n\nknownAlgebraic :: forall t k. KnownAlgebraic t => ((Fractional (HaskType t), Floating (HaskType t)) => k) -> k\nknownAlgebraic k = case kindVal @(TypKind t) of\n  SFloat -> case bitsVal @(TypBits t) of\n    SB32 -> k\n    SB64 -> k\n  SCmplx -> case bitsVal @(TypBits t) of\n    SB32 -> k\n    SB64 -> k\n  _ -> error \"KnownAlgebraic bug\"\n\nknownNum :: forall t k. 
KnownNumeric t => (KnownTyp t => Num (HaskType t) => k) -> k\nknownNum k = case kindVal @(TypKind t) of\n  SFloat -> case bitsVal @(TypBits t) of\n    SB32 -> k\n    SB64 -> k\n  SCmplx -> case bitsVal @(TypBits t) of\n    SB32 -> k\n    SB64 -> k\n  SBool -> error \"KnownNumeric bug\"\n  SInt -> case bitsVal @(TypBits t) of\n    SB32 -> k\n    SB64 -> k\n\nclass KnownKind t where kindVal :: SKind t\ninstance KnownKind 'Bool where kindVal = SBool\ninstance KnownKind 'Cmplx where kindVal = SCmplx\ninstance KnownKind 'Float where kindVal = SFloat\ninstance KnownKind 'Int where kindVal = SInt\n\ntype SList = NP Proxy\n\ninstance Ord (Sat KnownNat t) where\n  compare x@Sat y@Sat = compare (natVal x) (natVal y)\n\ninstance Eq (Sat KnownNat t) where\n   x@Sat == y@Sat = (natVal x) == (natVal y)\n\ntype SShape = NP (Sat KnownNat)\n\ninstance Ord (SShape s) where\n  compare x y = compare (shapeToList' x) (shapeToList' y)\n\ninstance Eq (SShape s) where\n  Unit == Unit = True\n  ((:*) x xs) == ((:*) y ys) = x == y && xs == ys\n\n\ninstance {-# OVERLAPPING #-} Show (SShape s) where\n  show x = show (shapeToList' x)\n\n\nsListLength :: NP f s -> Integer\nsListLength Unit = 0\nsListLength ((:*) _ s) = 1+sListLength s\n\nsListLen :: NP f s -> Int\nsListLen = fromIntegral . sListLength\n\nsListLenAsNat :: NP f s -> Sat KnownNat (Length s)\nsListLenAsNat Unit = Sat\nsListLenAsNat ((:*) _ s) = case sListLenAsNat s of\n  Sat -> Sat\n\ntype family PeanoLength xs :: Peano where\n  PeanoLength '[] = 'Zero\n  PeanoLength (x ': xs) = 'Succ (PeanoLength xs)\n\n\nwithKnownNat :: forall k. Int -> (forall (n::Nat). KnownNat n => Proxy n -> k) -> k\nwithKnownNat 0 f = f (Proxy @0)\nwithKnownNat 1 f = f (Proxy @1)\nwithKnownNat n f = withKnownNat (n `div` 2) (if n `mod` 2 == 0 then f2x else f2x1)\n  where f2x,f2x1 :: forall (n::Nat). KnownNat n => Proxy n -> k\n        f2x  _ = f (Proxy @(n*2))\n        f2x1 _ = f (Proxy @(n*2+1))\n\n-- Probably a GHC bug:\n-- withKnownNat'' :: forall k. 
Int -> (forall (n::Nat). KnownNat n => k) -> k\n-- withKnownNat'' 0 f = f @0\n-- withKnownNat'' n f = withKnownNat'' (n-1) fsucc\n--   where fsucc :: forall (n::Nat). KnownNat n =>  k\n--         fsucc = f @(n+1)\n\n-- This also fails:\n-- appProxy :: forall (n::Nat) k. KnownNat n => Proxy n -> (forall (m::Nat). KnownNat m => k) -> k\n-- appProxy f _ = f @n\n\n-- withKnownNat :: forall k. Int -> (forall (n::Nat). KnownNat n => k) -> k\n-- withKnownNat n f = withKnownNat' n (\\proxy -> appProxy proxy f)\n\nclass KnownNat (Length s) => KnownLen s where\n  shapePeano :: SPeano (PeanoLength s)\n  typeSList :: SList s\n\ninstance KnownLen '[] where\n  shapePeano = SZero\n  typeSList = Unit\n\ninstance KnownLen xs => KnownLen (x ': xs) where\n  shapePeano = SSucc (shapePeano @xs)\n  typeSList = (:*) Proxy (typeSList @xs)\n\nlistTypeLen :: forall xs. KnownLen xs => Integer\nlistTypeLen = sListLength (typeSList @xs)\n\ntypeSListProxy :: KnownLen xs => proxy xs -> SList xs\ntypeSListProxy _ = typeSList\n\nsListProxy :: NP f xs -> Proxy xs\nsListProxy _ = Proxy\n\nknownNatVal :: forall x. Sat KnownNat x -> Integer\nknownNatVal Sat = natVal (Proxy @x)\n\nshapeToList' :: SShape s -> [Integer]\nshapeToList' Unit = []\nshapeToList' ((:*) x xs) = knownNatVal x : shapeToList' xs\n\nshapeToList'' :: All KnownNat s => NP proxy s -> [Integer]\nshapeToList'' Unit = []\nshapeToList'' ((:*) x xs) = natVal x : shapeToList'' xs\n\nshapeToList :: ∀(s::Shape). KnownShape s => [Integer]\nshapeToList = shapeToList'' (typeSList @ s)\n\ntypeSShape :: forall s. KnownShape s => SShape s\ntypeSShape = sListSShape (typeSList @s)\n\nproxySShape :: forall s. KnownShape s => Proxy s -> SShape s\nproxySShape _ = typeSShape\n\nsListSShape :: forall s. 
All KnownNat s => SList s -> SShape s\nsListSShape = allKnown'\n\n\n\ntype None = 514229 --  fibonnaci prime.\n-- type None = 0 - 1 -- GHC does not like negative Nats.\n-- Using a maybe type would be a RPITA.\n\n\n--------------------------------\n-- Generation Effects (TODO: move to other module)\n\ndata VarInfo = forall s t. (KnownShape s, KnownTyp t) => VarInfo {varTrainable :: Bool,\n                                    varRef :: Ref String s t,\n                                    varInitial :: Maybe (T s t)} \n\nvarName :: VarInfo -> String\nvarName VarInfo {varRef =  Ref {..}} = refName\n\ndata GState = GState {nextVar :: Integer, -- ^ next free variable\n                      genRegularizers :: [Scalar Float32] -- ^ accumulated regularizers\n                     }\ninitialGstate :: GState\ninitialGstate = (GState {nextVar = 0\n                        ,genRegularizers=[]\n                        })\n\n\n\n  \ndata Gen a where\n  GPId :: Gen Integer\n  GPVariable :: forall (shape :: Shape) t. 
(KnownTyp t,KnownShape shape) => Bool -> String -> Maybe (T shape t) -> Gen (Ref String shape t) \n  GPModify :: (KnownShape s,KnownTyp t) => Ref Int s t -> T s t -> Gen (T s t)\n  GPReturn :: a -> Gen a\n  GPState :: (GState -> (a,GState)) -> Gen a\n  GPApp :: (Gen (a -> b)) -> Gen a -> Gen b\n  GPBind :: Gen a -> (a -> Gen b) -> Gen b\n\n\ngenGets :: (GState -> a) -> Gen a\ngenGets f = GPState  (\\s -> (f s, s))\n\ninstance Applicative Gen where\n  (<*>) = GPApp\n  pure = GPReturn\n\ninstance Monad Gen where\n  (>>=) = GPBind\n\ninstance Functor Gen where\n  fmap f = (pure f <*>)\n\n\n--------------------------\n-- Tensors\n\n-- | An indexing tensor in the format expected by GatherND\ntype IndexTensor indexShape containerShape w = T (indexShape ++ '[Length containerShape]) ('Typ 'Int w)\n\n-- | Description of a random distribution\ndata Distribution (s :: Shape) (t :: Typ) where\n  -- | Each element is from a truncated normal distribution with given standard dev.\n  TruncatedNormalD :: Float ->  Distribution s ('Typ 'Float w)\n  -- | Each element is from a uniform distribution with given bounds (low, high)\n  UniformD :: Float -> Float -> Distribution s ('Typ 'Float w)\n  OrthogonalD  :: Distribution '[m,n] ('Typ 'Float w)\n\ndata Ref r s t = Ref {refName :: r,\n                      refShape :: SShape s,\n                      refTyp :: STyp t}\n\ndata NilOp s t where\n  ExternalVar :: Ref String s t -> NilOp s t\n  Variable :: Ref Int s t -> NilOp s t\n  Constant :: HaskType t -> NilOp '[] t\n  Range :: KnownBits w => Sat KnownNat n -> NilOp '[n] ('Typ 'Int w)\n\ndata Catable s1 s2 t n = Catable (Sat KnownNat n) (T (s1 ++ (n ': s2)) t)\n  -- deriving Show\n\n\ntype Unique = Int\n\ndata T (s :: Shape) (t :: Typ) where\n  BroadcastT :: KnownTyp t => Maybe Unique -> Bool -> Sat KnownNat n -> SShape s -> T s t ->  T (n ': s) t\n  MapT :: KnownTyp t => Sat KnownNat n -> SShape s -> (T s t -> T r u) ->  T (n ': s) t -> T (n ': r) u\n  ZipT :: (KnownTyp t, KnownTyp u) 
=> Sat KnownNat n -> SShape s -> SShape r -> (T s t -> T r u -> T q v) ->  T (n ': s) t -> T (n ': r) u -> T (n ': q) v\n  Zip3T :: (KnownTyp t, KnownTyp u, KnownTyp v) => Sat KnownNat n -> SShape s -> SShape r -> SShape q -> (T s t -> T r u -> T q v -> T p w) ->  T (n ': s) t -> T (n ': r) u -> T (n ': q) v -> T (n ': p) w\n  T :: NilOp s t -> T s t\n  Noise :: Integer -> -- this is the unique noise identifier, preventing two different noises to ever be re-shared.\n           SShape s0 -> SShape s1 ->\n           Distribution s1 t ->\n           T (s0 ++ s1) t\n  BinOp :: (KnownTyp t,KnownTyp u) => BinOp s1 t s2 u s3 v -> SShape s0 -> SShape s1 -> STyp t -> SShape s2 -> STyp u -> T (s0 ++ s1) t -> T (s0 ++ s2) u -> T (s0 ++ s3) v\n  UnOp :: KnownTyp t => UnOp s1 t s2 u -> SShape s0 -> T (s0 ++ s1) t -> T (s0 ++ s2) u\n  Unbroadcast :: Sat KnownNat n -> Unique -> T (n ': s) t -> T s t\n  DirectBroadcast :: SShape s0 -> NP proxy' s1 -> SShape s2 -> NP proxy' s3 -> T (s0 ++ s2) t -> T (s0 ++ (s1 ++ (s2 ++ s3))) t\n  ReshapeFrom :: Product s ~ Product s0 => SShape s0 -> T s0 t -> T s t\n  Transpose :: SShape s0 -> Permutation s0 s -> T s0 t -> T s t\n  Concat :: SShape s0 -> SShape s1 -> NP (Catable s0 s1 t) ns -> T (s0 ++ (Sum ns ': s1)) t\n  Gather :: KnownTyp ('Typ 'Int w) => SShape indexShape -> SShape s0 -> Sat KnownNat m -> SShape s1\n    -> T (s0 ++ (m ': s1)) t -> T (s0 ++ indexShape) ('Typ 'Int w) -> T (s0 ++ indexShape ++ s1) t\n  GatherND :: KnownTyp ('Typ 'Int w) => SShape containerShape -> SShape elementShape -> SShape indexShape\n    -> T (containerShape ++ elementShape) t -> IndexTensor indexShape containerShape w -> T (indexShape ++ elementShape) t\n\n  MatMul :: forall s m n o t. 
KnownNumeric t => SShape s -> Sat KnownNat n -> Sat KnownNat  o -> Sat KnownNat m -> T (s ++ '[n,o]) t -> T (s ++ [o,m]) t -> T (s ++ [n,m]) t\n  Where :: T s TFBool  -> T s t -> T s t -> T s t\n  If :: Scalar TFBool -> T s t -> T s t -> T s t\n  Convolution :: KnownAlgebraic t => Sat KnownNat bs -> Sat KnownNat inChannels -> Sat KnownNat outChannels -> SShape filterSpatialShape -> SShape s\n            -> T (bs ': s ++ '[inChannels]) t -- input tensor (batched)\n            -> T (filterSpatialShape ++ '[inChannels,outChannels]) t -- filters\n            -> T (bs ': s ++ '[outChannels]) t\n  Pool :: Length outSpatial ~ Length window =>\n          Sat KnownNat bs -> SShape window -> PoolingType -> Sat KnownNat numChannels -> SShape outSpatial\n            -> T (bs ': ZipWithMulShapes window outSpatial ++ '[numChannels]) t\n            -> T (bs ': outSpatial ++ '[numChannels]) t\n  Softmax :: Sat KnownNat bs -> Sat KnownNat n -> T '[bs,n] (Flt w) -> T '[bs,n] (Flt w)\n    -- yes, softmax is shape-fixed: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/softmax\n\n-- instance Show Unique where\n--   show _ = \"<Unique>\"\n\n-- deriving instance (Show (T s t))\n\ntype family ZipWithMulShapes (xs::Shape) (xy::Shape) :: Shape\ntype instance ZipWithMulShapes (x ': xs) (y ': ys) = x*y ': ZipWithMulShapes xs ys\ntype instance ZipWithMulShapes '[] _ = '[]\ntype instance ZipWithMulShapes _ '[] = '[]\n\nsatMul :: forall n m. Sat KnownNat n -> Sat KnownNat m -> Sat KnownNat (n*m)\nsatMul Sat Sat = Sat\n\nsatProd :: SShape s -> Sat KnownNat (Product s)\nsatProd Unit = natSat @1\nsatProd (x :* xs) = satMul x (satProd xs)\n\nsatAdd :: forall n m. 
Sat KnownNat n -> Sat KnownNat m -> Sat KnownNat (n+m)\nsatAdd Sat Sat = Sat\n\nzipWithMulSShapes :: SShape xs -> SShape ys -> SShape (ZipWithMulShapes xs ys)\nzipWithMulSShapes Unit _ = Unit\nzipWithMulSShapes _ Unit = Unit\nzipWithMulSShapes ((:*) x xs) ((:*) y ys) = (:*) (satMul x y) (zipWithMulSShapes xs ys)\n\ndata PoolingType = MaxPool | AvgPool deriving Show\n\ntype Tensor shape = T shape\n\ndata ReduceOp = Mean | Max | Min | Sum\ndata Axis1Op s1 t s2 u where\n  ReverseT :: Sat KnownNat n ->  Axis1Op '[n] t '[n] t \n  ArgMax :: KnownNumeric t => Sat KnownNat n ->  Axis1Op '[n] t '[] ('Typ 'Int b)\n  OneHot :: KnownNumeric t => Sat KnownNat n ->  Axis1Op '[] ('Typ 'Int b) '[n] t\n  ReduceOp :: KnownNumeric t => Sat KnownNat n ->  ReduceOp -> Axis1Op '[n] t '[] t\n  SliceOp :: forall m n t proxy. proxy m -> Sat KnownNat n -> Integer -> Integer -> Axis1Op '[n] t '[m] t\n  AccessOp :: forall n t. Sat KnownNat n -> Integer -> Axis1Op '[n] t '[] t\n\ndata Float1Op\n  = ClipByValue Float Float\n  | Tanh\n  | Sin\n  | Exp\n  | Sigmoid\n  | HardSigmoid\n  | Relu\n  | Floor\n  | Round\n  | Cos\n  | Log\n  | Asin\n  | Acos\n  | Sinh\n  | Cosh\n  | Asinh\n  | Acosh\n  | Atan\n  | Atanh\n  | Sqrt\n  deriving Show\ndata Num1Op = Square | Negate | Abs | Sign\n  deriving Show\n\ndata Side = Upper | Lower\n\ndata UnOp (s1 :: Shape) (t :: Typ) (s2 :: Shape) (u :: Typ) where\n  ZeroTriangle :: KnownNumeric t => Sat KnownNat n -> Side -> Integer -> UnOp '[n,n] t '[n,n] t -- https://numpy.org/doc/1.16/reference/generated/numpy.tril.html\n  ExpM :: KnownNumeric t => Sat KnownNat n -> UnOp '[n,n] t '[n,n] t\n  Diag :: Sat KnownNat n -> UnOp '[n] t '[n,n] t\n  StopGradient :: UnOp '[] t '[] t\n  Cast :: UnOp '[] t '[] u\n  Conjugate :: UnOp '[] ('Typ 'Cmplx w) '[] ('Typ 'Cmplx w)\n  RealPart  :: UnOp '[] ('Typ 'Cmplx w) '[] ('Typ 'Float w)\n  Num1Op :: KnownNumeric t => Num1Op -> UnOp '[] t '[] t\n  Float1Op :: Float1Op -> UnOp '[] (Flt w) '[] (Flt w)\n  Axis1Op :: SShape s -> 
Axis1Op s1 t s2 u -> UnOp (s1 ++ s) t (s2 ++ s) u\n             -- deriving Show\n\ndata CompOp = Less | Greater | LessOrEqual | GreaterOrEqual\ndata LogicOp = And | Or\n\ndata Simple2Op t u where\n  Divide :: KnownAlgebraic t => Simple2Op t t\n  IntegerDiv :: Simple2Op ('Typ 'Int w) ('Typ 'Int w)\n  Equal :: KnownTyp t => Simple2Op t TFBool\n  Subtract :: KnownNumeric t => Simple2Op t t\n  Multiply :: KnownNumeric t => Simple2Op t t\n  Add :: KnownNumeric t => Simple2Op t t\n  Minimum :: KnownNumeric t => Simple2Op t t\n  Maximum :: KnownNumeric t => Simple2Op t t\n  FloorMod :: KnownNumeric t => Simple2Op t t\n  Comparision :: KnownNumeric t => CompOp -> Simple2Op t TFBool\n  Logic :: LogicOp -> Simple2Op TFBool TFBool\n  MkComplex :: Simple2Op (Flt w) ('Typ 'Cmplx w)\n\n-- deriving instance Show (Simple2Op t u)\n\ndata BinOp s1 t1 s2 t2 s3 t3 where\n  Simple2Op :: Simple2Op t u -> BinOp '[] t '[] t '[] u\n  SigmoidCrossEntropyWithLogits :: KnownFloat t => BinOp '[] t '[] t '[] t\n  SoftmaxCrossEntropyWithLogits :: KnownFloat t => BinOp '[n] t '[n] t '[] t\n  SparseSoftmaxCrossEntropyWithLogits :: BinOp '[] Int32 '[n] (Flt w) '[] (Flt w)\n\n-- deriving instance Show (BinOp a b c d e f)\n\ndata Permutation (s :: [k]) (t :: [k]) where\n  PermId :: Permutation s s\n  PermSkip :: Permutation s t -> Permutation (n ': s) (n ': t)\n  PermSwap :: Permutation (n ': m ': s) (m ': n ': s)\n  PermTrans :: Permutation s t -> Permutation t u -> Permutation s u\n\nderiving instance Show (Permutation s t)\n\nclass KnownTensors p where -- TODO: delete\n  -- | traverse all the tensors contained in p.\n  travTensor :: Applicative m => (forall s t. (KnownTyp t, KnownShape s) => String -> (T s t) -> m (T s t)) -> String -> p -> m p\n\ninstance (KnownTyp t, KnownShape shape) => KnownTensors (T shape t) where\n  travTensor f = f\n\ninstance (All KnownPair ys) => KnownTensors (HHTV ys) where\n  travTensor :: forall m. Applicative m => (forall s t'. 
(KnownTyp t', KnownShape s) => String -> T s t' -> m (T s t')) -> String -> HHTV ys -> m (HHTV ys)\n  travTensor f s = ttr 0\n    where ttr :: forall xs. All KnownPair xs => Int -> HHTV xs -> m (HHTV xs)\n          ttr _ Unit = pure Unit\n          ttr n (Uncurry x :* xs) = do\n            x' <- f (s <> \"_\" <> show n) x\n            xs' <- ttr (n + 1) xs\n            return (Uncurry x' :* xs')\n\ninstance (KnownTyp t, All KnownShape ys) => KnownTensors (HTV t ys) where\n  travTensor :: forall m. Applicative m => (forall s t'. (KnownTyp t', KnownShape s) => String -> T s t' -> m (T s t')) -> String -> (HTV t ys) -> m (HTV t ys)\n  travTensor f s = ttr 0\n    where ttr :: forall xs. All KnownShape xs => Int -> HTV t xs -> m (HTV t xs)\n          ttr _ Unit = pure Unit\n          ttr n (F x :* xs) = do\n            x' <- f (s <> \"_\" <> show n) x\n            xs' <- ttr (n + 1) xs\n            return (F x' :* xs')\n\ninstance (KnownTensors p, KnownTensors q) => KnownTensors (p,q) where\n  travTensor f s (x,y) = (,) <$> travTensor f (s<>\"_fst\") x <*> travTensor f (s<>\"_snd\") y\n\ninstance (KnownTensors p1, KnownTensors p2, KnownTensors p3) => KnownTensors (p1,p2,p3) where\n  travTensor f s (x,y,z) = (,,) <$> travTensor f (s<>\"_1\") x <*> travTensor f (s<>\"_2\") y <*> travTensor f (s<>\"_3\") z\n\ninstance (KnownTensors p1, KnownTensors p2, KnownTensors p3, KnownTensors p4) => KnownTensors (p1,p2,p3,p4) where\n  travTensor f s (x,y,z,w) = (,,,) <$> travTensor f (s<>\"_1\") x <*> travTensor f (s<>\"_2\") y <*> travTensor f (s<>\"_3\") z <*> travTensor f (s<>\"_4\") w\n\nclass KnownTensors p => ParamWithDefault p where\n  defaultInitializer :: Gen p\n\n\n"
  },
  {
    "path": "TypedFlow.hs",
    "content": "{-|\nModule      : TypedFlow\nDescription : Higher-Order Typed Binding to TensorFlow and Deep Learning Library\nCopyright   : (c) Jean-Philippe Bernardy, 2017\nLicense     : LGPL-3\nMaintainer  : jean-philippe.bernardy@gu.se\nStability   : experimental\n\nThis module re-exports all functions.\n-}\n\nmodule TypedFlow\n  (module TypedFlow.Types\n  ,module TypedFlow.TF\n  ,module  TypedFlow.Layers\n  ,module  TypedFlow.Learn\n  ,module GHC.TypeLits) where\n\nimport TypedFlow.TF\nimport TypedFlow.Types\nimport TypedFlow.Layers\nimport TypedFlow.Learn\nimport GHC.TypeLits\n\n"
  },
  {
    "path": "cabal.project",
    "content": "packages:\n  ./typedflow.cabal\n"
  },
  {
    "path": "docs/HOT.org",
    "content": "#+TITLE: TypedFlow: The HOT parts\n#+AUTHOR: Jean-Philippe Bernardy, University of Gothenburg\n\nTensorFlow™ is an open source software library for numerical\ncomputation using data flow graphs. Nodes in the graph represent\nmathematical operations, while the graph edges represent the\nmultidimensional data arrays (tensors) communicated between them.\nTensorFlow graphs can be efficiently evaluated on GPUs and is a\npopular choice to implement deep learning applications.\n\nTypedFlow is a higher-order and typed (HOT) frontend to tensorflow,\nwritten in (Glasgow) Haskell.\n\n In this talk I will:\n   - briefly explain what TensorFlow is and how it applies to deep learning\n   - recall the advantages of a HOT approach\n   - expose some example programs written using TypedFlow\n   - demonstrate how tensorflow functions can be given precise types,\n     using GHC extensions\n   - discuss some the difficulties of doing so\n\n* Machine learning in 45 seconds:\n\n- a vector of training inputs X :: [A]\n- a model f : (Θ × A) → ℝ⁺\n\nTask:\n\nGiven X, find θ such that f(θ,x) < ε, if\nx is considered similar to points in X.\n\nCommentary: Every point in X lie on a manyfold. We want to find what\nthis manyfold is. (Interpolation problem.)\n\n* \"Deep\" learning in 45 seconds\n\n- \"Deep\" ≡ f is \"complex\"\n- So we must use a brute force method to compute θ: stochastic\n  gradient descent (or variants thereof).\n\n- Typically, compute the gradient of f wrt. θ using AD.\n\n* Tensorflow\n\n- A (meta) programming language to define f. AD is builtin. 
(there is fineprint)\n\n- Restricted control flow (mostly, tensor-generalisations of +, *, -,\n  /, ^)\n\n- Typically programmed using python (standard in scientific computing)\n\n- \"Strongly typed\".\n  - but: \"broadcasting\"\n  - but: running the metaprogram can be quite slow ~1 minute (so type\n    errors can happen after 1 minute of loading the model --- and\n    programs can do other things before ...)\n  - but: types are typically not written as such. Given any two\n    functions, whether they compose (and what the composition does) is\n    a mystery unless one examines their code.\n\n* First-order culture\n\nhttps://github.com/fchollet/keras/blob/master/keras/layers/recurrent.py\n(search \"class LSTM\")\n\n* TypedFlow\n\n- A typed, higher-order frontend to tensorflow (basic tensor operations)\n- A library to construct neural networks\n- Generates python (yikes)\n\n* Typing tensors\n\n- tanh\n- matmul\n- concatT\n- repeatT\n- tile\n- convolution\n\n* Heterogeneous tensors\n\ntype HTV\n\n* Example Higher-order stuff\n\n- mapT\n- rnn\n- withBypass\n- Attention model\n\n* Complete examples\n\n- mnist\n- seq2seq\n\n* GHC woes\n\n- see transposeV\n\n* Summary\n\n- Some NN building blocks are naturally higher-order. Taking an\n  example (and simplifying) a recurrent neural network turns a tensor\n  function into a function between lists (vectorslists) of tensors.\n\n- Functional programming is ideally suited to program complicated\n  applications from building blocks.\n\n  Example: an \"Attention-model\" is a thing where every step in a RNN adds\n  a computation which depends on an external input. We can compose\n  usual RNN cells with attention models in several ways. The state of\n  the art is to reprogram all combinations by hand.\n\n- Typed APIs.\n\n  Types can be used to check the tensor dimensions. 
Types catch a lot\n  of errors, but they can also be used to *guide* the programming.\n\n  Types are pretty much a necessity in the presence of HO\n  functions.\n\n- TypedFlow is typically much closer to mathematical notation than\n  python. Programs are short to write and easier to read. Standard\n  building blocks can be swapped for custom versions quite easily.\n\n  Examples\n    - rnn stacking using \"residual connections\" instead of just\n      stacking.\n    - make it easy to share parameters between different components\n      (example: if we do a style translation we may want to share the\n      embedding layers between encoder and decoder parts)\n\n- Long game: integrate cutting edge ideas as they arrive with moderate\n  effort.\n\n* FAQ\n- Why not Agda, Idris?  A long term plan is to bypass python, so we'd\n  want a \"real\" programming language for the programming bits that go\n  around the TF program.\n"
  },
  {
    "path": "docs/Talk.org",
    "content": "#+TITLE: TypedFlow: A library for higher-order typed deep learning\n#+AUTHOR: Jean-Philippe Bernardy, University of Gothenburg\n\nTensorFlow is a library for numerical computation, with specific\nfeatures for machine-learning such as gradient computation. It is\nperhaps the most popular backend for deep learning applications.\nTypedFlow a higher-order and typed (HOT) frontend to TensorFlow\nwritten in Haskell, and a library of neural-network layers and\ncombinators.\n\n\nIn this talk I will:\n\n- briefly recall what TensorFlow is and how it applies to deep\n  learning\n- discuss the advantages of a HOT approach vs. plain TensorFlow\n- expose two use-cases: the standard MNIST example and a\nsequence-to-sequence network with attention model.\n\n\nIdeas: transparency, explainability\n\n\n\n* Machine learning in 45 seconds:\n\n- a vector of training inputs X::[A]\n- a model f : (Θ × A) → ℝ⁺\n\nTask:\n\nGiven X, find θ such that f(θ,x) < ε, if\nx is considered similar to points in X, and > ε otherwise.\n\nCommentary: Every point in X lie on a manifold. We want to find what\nthis manyfold is. (Interpolation problem.)\n\n* \"Deep\" learning in 45 seconds\n\n- \"Deep\" ≡ f is \"complicated\"\n- So we must use a brute force method to compute θ: stochastic\n  gradient descent (or variants thereof).\n\n- Typically, compute the gradient of f wrt. θ using AD.\n\n* Tensorflow\n\n- A (meta) programming language to define f. AD is builtin. (there is\n  fineprint)\n\n- Restricted control flow (mostly, tensor-generalisations of +, *, -,\n  /, ^, tanh, ...)\n\n- Typically programmed using python (standard in scientific computing)\n\n- \"Strongly typed\"\n  - but: no abstraction over dimensions\n  - but: \"brodcasting\"\n  - but: running the metaprogram can be quite slow ~1 minute (so type\n    errors can happen after 1 minute of loading the model --- and\n    programs can do other things before ...)\n  - but: types are typically not written as such. 
Given any two\n    functions, whether they compose (and what the composition does) is\n    a mystery unless one examines their code.\n\n- \"map\" has a surprising semantics (see below)\n\n* What is TypedFlow?\n\n- A typed, higher-order frontend to tensorflow\n  (basic tensor operations)\n- A library to construct neural networks\n- Generates python\n\n* Why TypedFlow?\n\nFunctional programming is ideally suited to program complicated\napplications from building blocks.\n\n- Notation\n- Types\n- HO\n\n* Deep Learning: The state of the art\n\n[[file:cards.jpg]]\n \n(Actually this has become worse!)\n* Notation\n\nHaskell is typically much closer to mathematical notation than\npython. Programs are short to write and easier to read.\n\nfile:../TypedFlow/TF.hs::/⊕.*::/\n\n* Why Types?\n\nTypes can be used to check the tensor dimensions.\n\n- Types catch a lot of errors\n- but they can also be used to *guide* the programming. \"Type holes\"\n  (see MNIST example)\n\nTypes are pretty much a necessity to take advantage of HO functions.\n\n#+BEGIN_QUOTE\nTogether with the absence of side effects, rich type systems enable to\nconstruct complex programs with a high degree of confidence:\n\n- types precisely abstract the intention of the programmer for each function,\n  without any hidden side effect, and\n- provided that they match the contracts imposed by types, functions\n  can be freely combined, using lazy evaluation and higher-order\n  facilities, without risk of pernicious interference.\n#+END_QUOTE\n\n* Python, aka The Culture of First Order\n\n[[file:imperiallegion.jpg]]\n\nhttps://github.com/fchollet/keras/blob/master/keras/layers/recurrent.py\n(search \"class LSTM\")\n\n* Example 1: LSTM\n\nfile:../TypedFlow/Layers/RNN.hs::/^lstm.*::/\n\n* Example 2: Attention\n\nExample: an \"Attention-model\" is a model where every step in a RNN\nadds a computation which depends on an external input. We can compose\nusual RNN cells with attention models in several ways. 
The state of\nthe art is to reprogram such combinations by hand.\n\nfile:../TypedFlow/Layers/RNN.hs::/^attentiveWithFeedback.*::/\n\n* Mapping tensors\n\n- Tensorflow's ~map~ spawns processes. This is (usually) quite a bad\n  idea --- tensor operations are parallelized anyway (but not on\n  several GPUs... the purpose of ~map~ apparently).\n\n- Most (but not all!) operations have so-called \"broadcast semantics\";\n  they can be (implicitly!) raised to tensors of higher dimensions.\n\n- file:../TypedFlow/Abstract.hs::/^protoBroadcast.*::/\n\n  - Note \"gather\" goes to \"gather_nd\"\n  - Certain convolutions can't be broadcasted at all 😿\n\n* Pretending that tensor operations are functional\n\n- They are EXCEPT that sharing is lost\n- Use the old trick of observable sharing. (Memoizing, etc.)\n\n* Long game\n\n- Integrate cutting edge DL ideas as they arrive with moderate effort.\n\n* MNIST\n\nfile:../examples/mnist/MNIST.hs\n\n* Seq2Seq\n\nfile:../examples/seq2seq/Seq2Seq.hs\n"
  },
  {
    "path": "examples/agreement/Aggr.hs",
    "content": "{-# LANGUAGE ApplicativeDo #-}\n{-# LANGUAGE ViewPatterns #-}\n{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UnicodeSyntax #-}\n\n\nimport TypedFlow\nimport TypedFlow.Python\nimport qualified GHC.Int as GHC\n\nonFST :: (Tensor s1 t -> Tensor s t) -> HTV t '[s1, s'] -> HTV t '[s, s']\nonFST f (VecPair h c) = (VecPair (f h) c)\n\nmkLSTM :: ∀ n x. KnownNat x => KnownNat n => \n        String -> DropProb -> Gen (RnnCell Float32 '[ '[n], '[n]] (Tensor '[x] Float32) (Tensor '[n] Float32))\nmkLSTM pName dropProb = do\n  params <- parameterDefault pName\n  drp1 <- mkDropout dropProb\n  rdrp1 <- mkDropout dropProb\n  return (timeDistribute drp1 .-. onStates (onFST rdrp1) (lstm params))\n\nmodel :: forall (vocSize::Nat) (len::Nat). 
KnownNat len => KnownNat vocSize =>\n   Gen (T '[len] Int32 -> T '[len] Int32 -> ModelOutput Float32 '[len,vocSize] '[])\nmodel = do\n  embs <- parameterDefault \"embs\"\n  let dropProb = DropProb 0.10\n  lstm1 <- mkLSTM @160 \"w1\" dropProb\n  drp <- mkDropout dropProb\n  w <- parameterDefault \"dense\"\n  return $ \\input gold -> do\n    let masks = constant 1 ⊝ cast @Float32 (equal (constant padding) input)\n        (_sFi,predictions) =\n          simpleRnn (timeDistribute (embedding @12 @vocSize embs) .-.\n            lstm1 .-.\n            timeDistribute drp .-.\n            timeDistribute (dense w))\n          (repeatT zeros, input)\n      in timedCategorical masks predictions gold\n\npadding :: GHC.Int32\npadding = 10\n\nmain :: IO ()\nmain = do\n  generateFile \"aggr.py\" (compile @512 defaultOptions (model @12 @21))\n  putStrLn \"done!\"\n\n-- >>> main\n-- Parameters (total 134300):\n-- dense_bias: T [12] tf.float32\n-- dense_w: T [160,12] tf.float32\n-- w1_o_b: T [160] tf.float32\n-- w1_o_w: T [172,160] tf.float32\n-- w1_c_b: T [160] tf.float32\n-- w1_c_w: T [172,160] tf.float32\n-- w1_i_b: T [160] tf.float32\n-- w1_i_w: T [172,160] tf.float32\n-- w1_f_b: T [160] tf.float32\n-- w1_f_w: T [172,160] tf.float32\n-- embs: T [12,12] tf.float32\n-- y: T [512,21] tf.int32\n-- x: T [512,21] tf.int32\n-- done!\n\n\n(|>) :: ∀ a b. a -> b -> (a, b)\n(|>) = (,)\ninfixr |>\n\n\n-- Local Variables:\n-- dante-repl-command-line: (\"nix-shell\" \".styx/shell.nix\" \"--pure\" \"--run\" \"cabal repl\")\n-- End:\n\n"
  },
  {
    "path": "examples/mnist/MNIST.hs",
    "content": "{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE ApplicativeDo #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UnicodeSyntax #-}\n{-# LANGUAGE NoStarIsType #-}\nmodule MNIST where\n\nimport TypedFlow\nimport TypedFlow.Python\n\natShape :: forall s t. T s t -> T s t\natShape x = x\n\nmnist :: Gen (Model '[784] Float32 '[10] '[10] '[] Float32)\nmnist = do\n  filters1 <- parameterDefault \"f1\"\n  filters2 <- parameterDefault \"f2\"\n  w1 <- parameterDefault \"w1\"\n  w2 <- parameterDefault \"w2\"\n  return $ \\input gold ->\n    let nn = dense @10 w2                       .\n             relu . dense @1024 w1              .\n             reshape @'[7 * 7 * 64]             .\n             maxPool2D @2 @2                    .\n             relu . conv @64 @'[5,5] filters2   .\n             maxPool2D @2 @2                    .\n             atShape @'[28,28,32]               .\n             relu . conv @32 @'[5,5] filters1   .\n             reshape @'[28,28,1]\n        logits = nn input\n\n    in categoricalDistribution logits gold\n\nmain :: IO ()\nmain = do\n  generateFile \"mnist_model.py\" (compile @100 defaultOptions mnist)\n  putStrLn \"done!\"\n\n-- >>> main\n-- Parameters (total 3274634):\n-- f1_filters: T [5, 5, 1, 32] tf.float32\n-- f1_biases: T [32] tf.float32\n-- f2_filters: T [5, 5, 32, 64] tf.float32\n-- f2_biases: T [64] tf.float32\n-- w1_w: T [3136, 1024] tf.float32\n-- w1_bias: T [1024] tf.float32\n-- w2_w: T [1024, 10] tf.float32\n-- w2_bias: T [10] tf.float32\n-- done!\n\n\n-- Local Variables:\n-- dante-repl-command-line: (\"nix-shell\" \".styx/shell.nix\" \"--pure\" \"--run\" \"cabal repl\")\n-- End:\n\n"
  },
  {
    "path": "examples/mnist/Makefile",
    "content": "test: mnist_model.py main.py\n\tnix-shell ../seq2seq/shell.nix --run \"python main.py\"\n\nmnist_model.py: MNIST.hs\n\tnix-shell ../../.styx/shell.nix --run \"ghci -i../.. MNIST.hs -e main\"\n\n"
  },
  {
    "path": "examples/mnist/main.py",
    "content": "import sys\nsys.path.append('../..') # so we can see the rts.\n\nimport typedflow_rts as tyf\nimport tensorflow as tf\nimport numpy as np\nfrom mnist_model import mkModel,runModel\nimport os\n\n# comment out if you don't have CUDA\ntyf.cuda_use_one_free_device()\n\noptimizer = tf.keras.optimizers.Adam(1e-4)\n\n# import tfds.image_classification.MNIST as mnist # need to package tfds\n\ndef train_generator(batch_size):\n    for _ in range(1000):\n        # (x,y) = mnist.batch(100)\n        yield {\"x\":np.zeros((100,784), dtype=np.float32), # FIXME\n               \"y\":np.zeros((100,10), dtype=np.float32)\n               }\n\n\nmodel = mkModel()\n\ntyf.train(optimizer,model,runModel,train_generator)\n"
  },
  {
    "path": "examples/mnist/mnist_model.py",
    "content": "import tensorflow as tf\ndef mkModel():\n  #shape: [25, 32]\n  var10000=tf.random.uniform([25, 32],\n                             minval=-0.32444283,\n                             maxval=0.32444283,\n                             dtype=tf.float32) # 0\n  #shape: [5, 5, 1, 32]\n  var10001=tf.reshape(var10000, [5, 5, 1, 32])\n  var10002=tf.Variable(name=\"f1_filters\", trainable=True, initial_value=var10001)\n  #shape: []\n  var10003=tf.constant(0.1, shape=[], dtype=tf.float32)\n  #shape: [32]\n  var10004=ERROR:BroadcastT(var10003)\n  #shape: [32]\n  var10005=tf.reshape(var10004, [32])\n  var10006=tf.Variable(name=\"f1_biases\", trainable=True, initial_value=var10005)\n  #shape: [800, 64]\n  var10007=tf.random.uniform([800, 64],\n                             minval=-8.3333336e-2,\n                             maxval=8.3333336e-2,\n                             dtype=tf.float32) # 1\n  #shape: [5, 5, 32, 64]\n  var10008=tf.reshape(var10007, [5, 5, 32, 64])\n  var10009=tf.Variable(name=\"f2_filters\", trainable=True, initial_value=var10008)\n  #shape: [64]\n  var10010=tf.reshape(var10004, [64])\n  var10011=tf.Variable(name=\"f2_biases\", trainable=True, initial_value=var10010)\n  #shape: [3136, 1024]\n  var10012=tf.random.uniform([3136, 1024],\n                             minval=-3.7977725e-2,\n                             maxval=3.7977725e-2,\n                             dtype=tf.float32) # 2\n  var10013=tf.Variable(name=\"w1_w\", trainable=True, initial_value=var10012)\n  #shape: [1024]\n  var10014=tf.random.truncated_normal([1024], stddev=0.1, dtype=tf.float32) # 3\n  var10015=tf.Variable(name=\"w1_bias\", trainable=True, initial_value=var10014)\n  #shape: [1024, 10]\n  var10016=tf.random.uniform([1024, 10],\n                             minval=-7.61755e-2,\n                             maxval=7.61755e-2,\n                             dtype=tf.float32) # 4\n  var10017=tf.Variable(name=\"w2_w\", trainable=True, initial_value=var10016)\n  #shape: 
[10]\n  var10018=tf.random.truncated_normal([10], stddev=0.1, dtype=tf.float32) # 5\n  var10019=tf.Variable(name=\"w2_bias\", trainable=True, initial_value=var10018)\n  return {\"batch_size\":100,\n          \"parameters\":[ var10002\n          , var10006\n          , var10009\n          , var10011\n          , var10013\n          , var10015\n          , var10017\n          , var10019 ],\n          \"paramsdict\":{\"f1_filters\":var10002,\n                        \"f1_biases\":var10006,\n                        \"f2_filters\":var10009,\n                        \"f2_biases\":var10011,\n                        \"w1_w\":var10013,\n                        \"w1_bias\":var10015,\n                        \"w2_w\":var10017,\n                        \"w2_bias\":var10019}}\n@tf.function\ndef runModel_fn(training_placeholder,\n                f1_filters,\n                f1_biases,\n                f2_filters,\n                f2_biases,\n                w1_w,\n                w1_bias,\n                w2_w,\n                w2_bias,\n                x,\n                y):\n  #shape: [100, 10]\n  var10020=y\n  #shape: [100, 784]\n  var10021=x\n  #shape: [100, 28, 28, 1]\n  var10022=tf.reshape(var10021, [100, 28, 28, 1])\n  #shape: [5, 5, 1, 32]\n  var10023=f1_filters\n  #shape: [100, 28, 28, 32]\n  var10024=tf.nn.convolution(var10022, var10023, padding=\"SAME\", data_format=\"NHWC\")\n  #shape: [100, 784, 32]\n  var10025=tf.reshape(var10024, [100, 784, 32])\n  #shape: [32]\n  var10026=f1_biases\n  #shape: [784, 32]\n  var10027=tf.broadcast_to(tf.reshape(var10026, [1, 32]), [784, 32])\n  #shape: [100, 784, 32]\n  var10028=tf.broadcast_to(tf.reshape(var10027, [1, 784, 32]), [100, 784, 32])\n  #shape: [100, 784, 32]\n  var10029=tf.add(var10025, var10028)\n  #shape: [100, 28, 28, 32]\n  var10030=tf.reshape(var10029, [100, 28, 28, 32])\n  #shape: [100, 28, 28, 32]\n  var10031=tf.nn.relu(var10030)\n  #shape: [100, 28, 28, 32]\n  var10032=tf.reshape(var10031, [100, 28, 28, 32])\n  
#shape: [100, 14, 14, 32]\n  var10033=tf.nn.pool(var10032, [2, 2], \"MAX\", strides=[2, 2], padding=\"SAME\")\n  #shape: [100, 14, 14, 32]\n  var10034=tf.reshape(var10033, [100, 14, 14, 32])\n  #shape: [5, 5, 32, 64]\n  var10035=f2_filters\n  #shape: [100, 14, 14, 64]\n  var10036=tf.nn.convolution(var10034, var10035, padding=\"SAME\", data_format=\"NHWC\")\n  #shape: [100, 196, 64]\n  var10037=tf.reshape(var10036, [100, 196, 64])\n  #shape: [64]\n  var10038=f2_biases\n  #shape: [196, 64]\n  var10039=tf.broadcast_to(tf.reshape(var10038, [1, 64]), [196, 64])\n  #shape: [100, 196, 64]\n  var10040=tf.broadcast_to(tf.reshape(var10039, [1, 196, 64]), [100, 196, 64])\n  #shape: [100, 196, 64]\n  var10041=tf.add(var10037, var10040)\n  #shape: [100, 14, 14, 64]\n  var10042=tf.reshape(var10041, [100, 14, 14, 64])\n  #shape: [100, 14, 14, 64]\n  var10043=tf.nn.relu(var10042)\n  #shape: [100, 14, 14, 64]\n  var10044=tf.reshape(var10043, [100, 14, 14, 64])\n  #shape: [100, 7, 7, 64]\n  var10045=tf.nn.pool(var10044, [2, 2], \"MAX\", strides=[2, 2], padding=\"SAME\")\n  #shape: [100, 3136]\n  var10046=tf.reshape(var10045, [100, 3136])\n  #shape: [3136, 1024]\n  var10047=w1_w\n  #shape: [100, 1024]\n  var10048=tf.matmul(var10046, var10047)\n  #shape: [100, 1024]\n  var10049=tf.reshape(var10048, [100, 1024])\n  #shape: [1024]\n  var10050=w1_bias\n  #shape: [100, 1024]\n  var10051=tf.broadcast_to(tf.reshape(var10050, [1, 1024]), [100, 1024])\n  #shape: [100, 1024]\n  var10052=tf.add(var10049, var10051)\n  #shape: [100, 1024]\n  var10053=tf.nn.relu(var10052)\n  #shape: [100, 1024]\n  var10054=tf.reshape(var10053, [100, 1024])\n  #shape: [1024, 10]\n  var10055=w2_w\n  #shape: [100, 10]\n  var10056=tf.matmul(var10054, var10055)\n  #shape: [100, 10]\n  var10057=tf.reshape(var10056, [100, 10])\n  #shape: [10]\n  var10058=w2_bias\n  #shape: [100, 10]\n  var10059=tf.broadcast_to(tf.reshape(var10058, [1, 10]), [100, 10])\n  #shape: [100, 10]\n  var10060=tf.add(var10057, var10059)\n  #shape: 
[100]\n  var10061=tf.nn.softmax_cross_entropy_with_logits(labels=var10020, logits=var10060)\n  #shape: [100]\n  var10062=tf.reshape(var10061, [100])\n  #shape: []\n  var10063=tf.reduce_mean(var10062, axis=0)\n  #shape: []\n  var10064=tf.constant(0.0, shape=[], dtype=tf.float32)\n  #shape: [1]\n  var10065=tf.broadcast_to(tf.reshape(var10064, [1]), [1])\n  #shape: []\n  var10066=tf.reshape(var10065, [])\n  #shape: []\n  var10067=tf.add(var10063, var10066)\n  #shape: [100]\n  var10068=tf.argmax(var10060, axis=1, output_type=tf.int32)\n  #shape: [100]\n  var10069=tf.argmax(var10020, axis=1, output_type=tf.int32)\n  #shape: [100]\n  var10070=tf.equal(var10068, var10069)\n  #shape: [100]\n  var10071=tf.cast(var10070, tf.float32)\n  #shape: [100]\n  var10072=tf.reshape(var10071, [100])\n  #shape: []\n  var10073=tf.reduce_mean(var10072, axis=0)\n  #shape: [100, 10]\n  var10074=tf.reshape(var10060, [100, 10])\n  #shape: [100, 10]\n  var10075=tf.nn.softmax(var10074, axis=1)\n  #shape: [100, 10]\n  var10076=tf.reshape(var10075, [100, 10])\n  return {\"loss\":var10067, \"accuracy\":var10073, \"y_\":var10076}\nrunModel = {\"function\":runModel_fn,\n            \"batched\":True,\n            \"placeholders\":{\"x\":{\"shape\":[100, 784], \"dtype\":tf.float32},\n                            \"y\":{\"shape\":[100, 10], \"dtype\":tf.float32}}}"
  },
  {
    "path": "examples/seq2seq/GenTr.hs",
    "content": "import Control.Applicative\nimport Test.QuickCheck.Gen\nimport Data.List\nimport Data.Array\ndata Abs a = Bin a (Abs a) (Abs a) | Leaf a deriving Show\n\ntype Method a =  a -> [a] -> [a] -> [a]\n\nparens :: String -> String\nparens xs = \"(\" ++ xs ++ \")\"\n\npreorder :: Char -> [Char] -> [Char] -> String\npreorder x l r =  (x : l ++ r)\npostorder :: Char -> [Char] -> [Char] -> String\npostorder x l r =  (l ++ r ++ [x])\nreversePO :: Char -> [Char] -> [Char] -> String\nreversePO x l r =  (x : r ++ l)\n\nlinearize _ (Leaf x) = [x]\nlinearize m (Bin x l r) = parens (m x (lin l) (lin r))\n  where lin = linearize m\n\nmkMethods :: Eq a => [(a->Bool,Method a)] -> Method a\nmkMethods ms x = case find (\\(p,_) -> p x) ms of\n  Just (_,m) -> m x\n  Nothing -> error \"no applicable linearization method\"\n\nlinPO :: Abs Char -> [Char]\nlinPO = linearize (mkMethods [(const True,preorder)])\n\nlin1 :: Abs Char -> [Char]\nlin1 = linearize (mkMethods [(\\x -> x < '3',reversePO),(const True,preorder)])\n\n\nex :: Abs Char\nex = Bin 'a' (Bin '1' (Leaf 'b') (Leaf 'c')) (Leaf 'd')\n\nguard :: Alternative f => Bool -> f a -> f a\nguard True x = x\nguard False _ = empty\n\n\narb :: Gen (Abs Char)\narb = sized $ \\n -> do\n  oneof (take (max 1 n) [(Leaf <$> elements ['a'..'e'])\n                        ,resize (n-1) (Bin <$> elements ['0'..'4'] <*> arb <*> arb)])\n\narbOkSize :: Gen (Abs Char)\narbOkSize = do\n  x <- resize 6 arb\n  let xx = linPO x\n  if (length xx > 2 && length xx < 22)\n    then return x\n    else arbOkSize\n\nmySample :: Int -> IO [Abs Char]\nmySample n = generate (sequence $ replicate n arbOkSize)\n\nshowEx :: Abs Char -> String\nshowEx x = linPO x ++ \"\\t\" ++ lin1 x\n\n\ntest :: IO ()\ntest = mapM_ putStrLn . map showEx =<< mySample 10\n\n\nmain :: IO ()\nmain = writeFile \"synthtrees.txt\" . unlines . map showEx =<< mySample 100000\n"
  },
  {
    "path": "examples/seq2seq/Makefile",
    "content": "test: s2s.py synthtrees.txt main.py\n\tnix-shell --run \"python main.py\"\n\ns2s.py: Seq2Seq.hs\n\tnix-shell ../../.styx/shell.nix --run \"ghci -i../.. Seq2Seq.hs -e main\"\n\nsynthtrees.txt: GenTr.hs\n\tnix-shell ../../.styx/shell.nix --run \"ghc --make GenTr\"\n\t./GenTr\n"
  },
  {
    "path": "examples/seq2seq/Seq2Seq.hs",
    "content": "{-# LANGUAGE AllowAmbiguousTypes #-}\n{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}\n{-# LANGUAGE DataKinds #-}\n{-# LANGUAGE RankNTypes #-}\n{-# LANGUAGE ScopedTypeVariables #-}\n{-# LANGUAGE TypeApplications #-}\n{-# LANGUAGE TypeFamilies #-}\n{-# LANGUAGE TypeOperators #-}\n{-# LANGUAGE UnicodeSyntax #-}\n\nmodule Main  where\n\nimport TypedFlow\nimport TypedFlow.Python\n\nmkLSTM :: ∀ n x w.\n       KnownNat x => KnownNat n => KnownBits w\n       => String\n       -> Gen (RnnCell w '[ '[n], '[n]] (Tensor '[x] (Flt w))\n               (Tensor '[n] (Flt w)))\nmkLSTM pName = do\n  params <- parameterDefault pName\n  drp1 <- mkDropout (DropProb 0.05)\n  rdrp1 <- mkDropouts (DropProb 0.05)\n  return (timeDistribute drp1 .-. onStates rdrp1 (lstm params))\n\nencoder :: forall (lstmSize :: Nat) (vocSize :: Nat) (n :: Nat)  w. \n           KnownNat lstmSize => KnownNat vocSize\n        => (KnownNat n) => KnownBits w\n        => String\n        -> Gen\n        (\n        T '[] Int32 -- length\n        -> Tensor '[n] Int32 -> \n          ((HTV (Flt w) '[ '[lstmSize], '[lstmSize] ], Tensor '[n, lstmSize] (Flt w))))\nencoder prefix = do\n  embs <- parameterDefault (prefix++\"embs\")\n  lstm1 <- mkLSTM (prefix++\"lstm1\")\n  return $ \\len input ->\n    runRnn\n       (iterateWithCull len (timeDistribute (embedding @vocSize @vocSize embs) .-. 
lstm1))\n       (repeatT zeros, input)\n\ndecoder :: forall (lstmSize :: Nat) (n :: Nat) (outVocabSize :: Nat) (d::Nat) w.\n                 KnownNat lstmSize => KnownNat d => (KnownNat outVocabSize, KnownNat n) => KnownBits w =>\n                 String\n                 -> Gen (\n                 T '[] Int32 -- ^ length\n                 -> T '[n, d] (Flt w) -- todo: consider a larger size for the output string\n                 -> HTV (Flt w) '[ '[lstmSize], '[lstmSize] ]\n                 -> Tensor '[n] Int32\n                 -> Tensor '[n, outVocabSize] (Flt w))\ndecoder prefix = do\n  -- note: for an intra-language translation the embeddings can be shared easily.\n  projs <- parameterDefault (prefix++\"proj\")\n  lstm1 <- mkLSTM (prefix++\"lstm1\")\n  embs <- parameterDefault \"embs\"\n  w1 <- parameter' (prefix++\"att1\") =<< glorotUniform\n  return $ \\ lens hs thoughtVectors targetInput ->\n    let attn = uniformAttn (multiplicativeScoring w1) lens hs -- NOTE: attention on the left-part of the input.\n        (_sFinal,outFinal) = simpleRnn\n          ((timeDistribute (embedding @outVocabSize @outVocabSize embs)\n                         .-.\n                         attentiveWithFeedback attn lstm1\n                         .-.\n                         timeDistribute (dense projs)))\n          ((F zeros :* thoughtVectors), targetInput)\n    in outFinal\n\n\nseq2seq :: forall (vocSize :: Nat) (n :: Nat).\n                 KnownNat vocSize => (KnownNat n)\n        => Gen (Placeholders\n                     '[ '(\"tgt_weights\", '[n], Float32),\n                        '(\"src_in\", '[n], Int32),\n                        '(\"src_len\", '[],  Int32),\n                        '(\"tgt_in\", '[n], Int32),\n                        '(\"tgt_out\", '[n], Int32)] ->\n                ModelOutput Float32 '[n, vocSize] '[])\nseq2seq  = do\n  enc <- encoder @256 @vocSize \"enc\"\n  dec <- decoder \"dec\"\n  return $ \\(PHT masks :* PHT input :* PHT inputLen :* PHT tgtIn 
:* PHT tgtOut :* Unit) ->\n    let (VecPair t1 t2,h) = enc inputLen input\n        y_ = dec inputLen h (VecPair t1 t2) tgtIn\n    in timedCategorical masks y_ tgtOut\n\n\n\n\nmain :: IO ()\nmain = generateFile \"s2s.py\" (compileGen @256\n                               defaultOptions {maxGradientNorm = Just 5}\n                               (stateless <$> seq2seq @15 @22))\n\n-- >>> main\n-- Parameters (total 889041):\n-- decatt1: T [256,256] tf.float32\n-- embs: T [15,15] tf.float32\n-- declstm1_o_bias: T [256] tf.float32\n-- declstm1_o_w: T [527,256] tf.float32\n-- declstm1_c_bias: T [256] tf.float32\n-- declstm1_c_w: T [527,256] tf.float32\n-- declstm1_i_bias: T [256] tf.float32\n-- declstm1_i_w: T [527,256] tf.float32\n-- declstm1_f_bias: T [256] tf.float32\n-- declstm1_f_w: T [527,256] tf.float32\n-- decproj_bias: T [15] tf.float32\n-- decproj_w: T [256,15] tf.float32\n-- enclstm1_o_bias: T [256] tf.float32\n-- enclstm1_o_w: T [271,256] tf.float32\n-- enclstm1_c_bias: T [256] tf.float32\n-- enclstm1_c_w: T [271,256] tf.float32\n-- enclstm1_i_bias: T [256] tf.float32\n-- enclstm1_i_w: T [271,256] tf.float32\n-- enclstm1_f_bias: T [256] tf.float32\n-- enclstm1_f_w: T [271,256] tf.float32\n-- encembs: T [15,15] tf.float32\n\n-- Local Variables:\n-- dante-repl-command-line: (\"nix-shell\" \".styx/shell.nix\" \"--pure\" \"--run\" \"cabal repl\")\n-- End:\n\n"
  },
  {
    "path": "examples/seq2seq/main.py",
    "content": "import sys\nsys.path.append('../..') # so we can see the rts.\n\nimport typedflow_rts as tyf\nimport tensorflow as tf\nimport numpy as np\nfrom s2s import mkModel\nimport os\nimport math\nimport random\n\n\n# comment out if you don't have CUDA\ntyf.cuda_use_one_free_device()\n\nchars = sorted(list(\"()01234abcde^$ \"))\n\nprint('total chars:', len(chars))\nchar_indices = dict((c, i) for i, c in enumerate(chars))\nindices_char = dict((i, c) for i, c in enumerate(chars))\n\nMAXLEN = 22\ndef pad(ws): return (ws + ' '*(MAXLEN - len(ws)))\n\ndef encode(s):\n    # print (\"proun\", s)\n    return np.array([char_indices[c] for c in s])\n\ndef decode(s): return \"\".join([indices_char[c] for c in list(s)])\n\ndef pad_right(sentence): return (MAXLEN - len(sentence)) * \" \" + sentence\ndef pad_left(sentence): return  sentence + (MAXLEN - len(sentence)) * \" \"\n\ndef source_input_conversion(s):\n    return encode(pad_left(s))\n\ndef target_input_conversion(sentence):\n    return encode(pad_left(\"^\"+sentence))\n\ndef target_output_conversion(sentence):\n    return encode(pad_left(sentence+\"$\"))\n\ndef sentence_target_weights(sentence):\n    l = len(sentence)\n    w = (l + 1) * [1] + (MAXLEN - (l + 1)) * [0]\n    return np.array(w)\n\ndef map(f,l):\n    return [f(x) for x in l]\n\ndef make_examples(l):\n    (l1,l2) = zip(*l)\n    return {\"src_in\":map(source_input_conversion,l1),\n            \"src_len\":map(len,l1),\n            \"tgt_in\":map(target_input_conversion,l2),\n            \"tgt_out\":map(target_output_conversion,l2),\n            \"tgt_weights\":map(sentence_target_weights,l2)}\n\ndef s2s_generator(src_len,src_in,tgt_in,tgt_out,tgt_weights):\n    def gen(bs):\n      for i in range(0, bs*(len(src_in)//bs), bs):\n          # print ({\"src_len\":src_len[i:i+bs], \"src_in\":src_in[i:i+bs], \"tgt_in\":tgt_in[i:i+bs], \"tgt_out\":tgt_out[i:i+bs], \"tgt_weights\":tgt_weights[i:i+bs]})\n          yield {\"src_len\":src_len[i:i+bs],\n                 
\"src_in\":src_in[i:i+bs],\n                 \"tgt_in\":tgt_in[i:i+bs],\n                 \"tgt_out\":tgt_out[i:i+bs],\n                 \"tgt_weights\":tgt_weights[i:i+bs]}\n    return gen\n\n\ndef my_sample(l,n):\n    return list(random.sample(l,min(n,len(l))))\n\nprint(\"Reading sentences...\")\nall_sentences = [l.strip().split(\"\\t\") for l in open(\"synthtrees.txt\").readlines()]\n\nval_set = make_examples(all_sentences[:2000])\ntrain_set = make_examples(all_sentences[2000:])\n\nprint(\"Loading model\")\nmodel = mkModel(tf.train.AdamOptimizer())\nsess = tf.Session()\nsaver = tf.train.Saver()\n\ndef printer(x):\n    (p,y,h) = x\n    print(\"Prob\", p, decode(y),h)\n\n\ndef translate(s):\n    r = tyf.beam_translate(sess,model,14,\n                           source_input_conversion(s),\n                           len(s),\n                           char_indices[\"^\"], char_indices[\"$\"],\n                           printer)\n    for x in r: printer(x)\n\ndef translate_cb(values):\n    if values[\"epoch\"] % 10 == 0:\n        save_path = saver.save(sess, \"model.ckpt\")\n        translate(\"(1(3cb)b)\")\n        print (\"Desired:\", \"(1b(3cb))\")\n        return False\n\ntyf.initialize_params(sess,model)\ntrain_stats = tyf.train(sess,\n                        model,\n                        s2s_generator(**train_set),\n                        valid_generator = s2s_generator(**val_set),\n                        epochs=5000,\n                        callbacks=[tyf.StopWhenAccurate(.01), translate_cb])\n\ntranslate(\"(1(3cb)b)\")\ntranslate(\"(1(2c(3e(4(1cb)b)))c)\")\n"
  },
  {
    "path": "examples/seq2seq/shell.nix",
    "content": "{ bootstrap ? import <nixpkgs> {} }:\nlet nixpkgs_source = fetchTarball https://github.com/NixOS/nixpkgs/archive/nixos-20.03.tar.gz;\n    # nixpkgs_source = fetchTarball https://github.com/NixOS/nixpkgs/archive/4cf0b6ba5d5ab5eb20a88449e0612f4dad8e4c29.tar.gz;\n    # nixpkgs_source = bootstrap.fetchFromGitHub { # for safety of checking the hash\n    #    owner = \"jyp\";\n    #    repo = \"nixpkgs\";\n    #    rev = \"6b911c2d99ad116fca338fc26de86b8859079322\";\n    #    sha256 = \"1bhwjkynya653mvpc4wwqks6kxnc06gyw6sbpwp8dbyr444ms4bd\";\n    #  };\n    # nixpkgs_source = ~/repo/nixpkgs;\n\nin with (import nixpkgs_source {}).pkgs;\nlet py = (pkgs.python37.withPackages (ps: [ps.tensorflow-bin_2 ps.nltk]));\n\nin pkgs.stdenv.mkDerivation {\n  name = \"my-env-0\";\n  buildInputs = [ py ];\n}\n\n"
  },
  {
    "path": "styx.yaml",
    "content": "local-packages:\n  typedflow:\n    location: .\n\nnix-deps:\n    - QuickCheck\n    - hscolour\n\n# non-haskell-deps:\n#     - glibcLocales\n\nnixpkgs:\n  # commit: 80812af9e46167e3104038f2af6de251f90823a8\n  # sha256: 0b718zkn5lhy71pyp0klbz7w872zck0ljqfk17f0b56k3rlvp1sy\n  url: https://github.com/NixOS/nixpkgs/archive/nixos-21.05.tar.gz\n"
  },
  {
    "path": "typedflow.cabal",
    "content": "name:           typedflow\nversion:        0.9\ncategory:       Deep Learning\nsynopsis:       Typed frontend to TensorFlow and higher-order deep learning\ndescription: TypedFlow is a typed, higher-order frontend to TensorFlow and a high-level library for deep-learning.\n             .\n             The main design principles are:\n             .\n               - To make the parameters of layers explicit. This choice makes sharing of parameters explicit and allows to implement \"layers\" as pure functions.\n             .\n               - To provide as precise as possible types. Functions are explicit about the shapes and elements of the tensors that they manipulate (they are often polymorphic in shapes and elements though.)\n             .\n               - To let combinators be as transparent as possible. If a NN layers is a simple tensor transformation it will be exposed as such.\nlicense:        LGPL-3\nlicense-file:   LICENSE\nauthor:         Jean-Philippe Bernardy\nmaintainer:     jean-philippe.bernardy@gu.se\nCabal-Version:  >= 1.12\nbuild-type:     Simple\nsource-repository head\n  type:     git\n  location: git@github.com:GU-CLASP/TypedFlow.git\n\nlibrary\n  default-language: Haskell2010\n  build-depends:\n    base==4.*,\n    ghc-typelits-knownnat,\n    prettyprinter,\n    mtl,\n    containers\n    -- ,tensorflow-opgen, tensorflow, tensorflow-core-ops, tensorflow-ops\n\n  exposed-modules:\n       TypedFlow,\n       TypedFlow.Layers,\n       TypedFlow.Layers.Core,\n       TypedFlow.Layers.RNN,\n       TypedFlow.Layers.RNN.Base,\n       TypedFlow.Layers.RNN.Cells,\n       TypedFlow.Layers.RNN.Attention,\n       TypedFlow.Learn,\n       TypedFlow.Models.Topic,\n       TypedFlow.Models.Transformer,\n       TypedFlow.Python,\n       TypedFlow.TF,\n       TypedFlow.Types,\n       TypedFlow.Types.Proofs\n\n  other-modules:\n        TypedFlow.Memo\n        TypedFlow.Memo2\n        TypedFlow.Abstract\n        TypedFlow.Broadcast\n"
  },
  {
    "path": "typedflow_rts.py",
    "content": "import tensorflow as tf\nimport numpy as np\nimport sys\nfrom time import time\nimport os\nimport random\n\n###############################################################\n# Devices\n###############################################################\n\n\ndef cuda_use_device(n):\n    \"\"\"Attempt to use a given CUDA device by setting the appropriate environment variables\"\"\"\n    os.environ[\"CUDA_DEVICE_ORDER\"]= \"PCI_BUS_ID\"\n    if os.environ.get(\"CUDA_VISIBLE_DEVICES\") is None:\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(n)\n\ndef find_free_cuda_device():\n    currentGPU = -1\n    gpuMemory=dict()\n    gpuUtil=dict()\n    for line in os.popen(\"nvidia-smi -q\"):\n        fields = list(map(lambda x: x.strip(), line.split(\":\")))\n        k = fields[0]\n        if k == \"Minor Number\":\n            currentGPU += 1\n            gpuMemory[currentGPU] = 0\n        elif k == \"Used GPU Memory\":\n            gpuMemory[currentGPU] = int(fields[1][:-4]) # last characters are \" MiB\"\n        elif k == \"Gpu\":\n            gpuUtil[currentGPU] = fields[1] # last characters are \" %\"\n    minUse = min(gpuMemory.values())\n    freeGpus = [g for g in gpuMemory.keys() if gpuMemory[g] == minUse]\n    if freeGpus == []:\n        print(\"No free GPU could be found.\")\n        assert False\n    else:\n        result = random.choice(freeGpus)\n        print (\"Found device\",result,\"currently used at\",gpuUtil[result],\"and with\",gpuMemory[result],\"MB taken.\")\n        return result\n\ndef cuda_use_one_free_device():\n    \"\"\"Attempt to use a free CUDA device by setting the appropriate environment variables\"\"\"\n    cuda_use_device(find_free_cuda_device())\n\n###############################################################\n# Generators\n###############################################################\n\ndef bilist_generator(l):\n    \"\"\"\n    Given a pair of x and y (each being a list or a np array) and a\n    batch size, return a 
generator function which will yield the input\n    in bs-sized chunks. Attention: if the size of the input is not\n    divisible by bs, then the remainder will not be fed. Consider\n    shuffling the input.\n    \"\"\"\n    (l0,l1) = l\n    def gen(bs):\n        if len(l0) == 0:\n            return\n        for i in range(0, bs*(len(l0)//bs), bs):\n            yield {\"x\":l0[i:i+bs],\"y\":l1[i:i+bs]}\n    return gen\n\n\ndef bilist_generator_transposed(model,l):\n    '''\n    Given a pair of l=(x,y) (both x,y being a list or a np array) and a\n    batch size, return a generator function which will yield the input\n    in bs*maxlen-sized chunks. This generator is intended to be used for\n    stateful language models. That is, row j of batch i+1 continues the\n    input exactly where row j of batch i ended.\n    '''\n    (batch_size,maxlen) = model[\"x\"].shape\n    (xs,ys) = l\n    num_items = len(xs) // (batch_size*maxlen)\n    x = np.zeros(shape=(num_items,batch_size,maxlen))\n    y = np.zeros(shape=(num_items,batch_size,maxlen))\n    for i in range(num_items):\n        for j in range(batch_size):\n            for k in range(maxlen):\n                x[i][j][k] = xs[k+j*(num_items*maxlen)+i*maxlen]\n                y[i][j][k] = ys[k+j*(num_items*maxlen)+i*maxlen]\n    def gen(_bs):\n        nonlocal num_items, x, y\n        for i in range(num_items):\n            yield {\"x\":x[i],\"y\":y[i]}\n    return gen\n\ndef dict_generator (xs):\n    k0 = next (iter (xs.keys())) # at least one key is needed\n    total_len = len(xs[k0])\n\n    def gen(bs):\n      for i in range(0, bs*(total_len//bs), bs):\n        yield dict((k,xs[k][i:i+bs]) for k in xs)\n\n    return gen\n\n\ndef initialize_params (session,model):\n    '''Initialize the learnable parameters of the model'''\n    # it'd be nice to do:\n\n    # session.run(tf.variables_initializer(model[\"params\"]))\n\n    # However this does not initialize the optimizer's variables. 
So,\n    # instead we do:\n\n    session.run(tf.local_variables_initializer())\n    session.run(tf.global_variables_initializer())\n\ndef train (optimizer, model_static, model_fn,\n           train_generator=bilist_generator(([],[])),\n           valid_generator=bilist_generator(([],[])),\n           epochs=100,\n           callbacks=[],\n           extraVectors=[]):\n    '''\n    Train the given model.\n\n    train_generator: training data\n\n    valid_generator: validation data\n\n    epochs: number of epochs\n\n    callbacks: list of callbacks.\n      Each callback receives an epoch entry (see below). If it returns True then the training is aborted.\n\n    extraVectors: list of extra vectors to pass to session.run when training.\n\n    modelPrefix: in case of a multitask/multimodel, give the prefix of the model to use.\n\n    This function returns a list of epoch entries. Each entry is a dictionary with:\n     - \"epoch\": current epoch\n     - \"val\" and \"train\": dictionaries with\n        - \"loss\", \"accuracy\", \"error_rate\", \"time\", \"start_time\", \"end_time\"\n    '''\n    batch_size = model_static[\"batch_size\"]\n    train_vars = model_static[\"parameters\"]\n    placeholders_info = model_fn[\"placeholders\"]\n    stats = []\n    def halfEpoch(isTraining):\n        totalAccur = 0\n        totalLoss = 0\n        n = 0\n        print (\"Training\" if isTraining else \"Validation\", end=\"\")\n        start_time = time()\n        for inputs in train_generator(batch_size) if isTraining else valid_generator(batch_size):\n            cast_inputs = dict((k,tf.cast(inputs[k], placeholders_info[k][\"dtype\"])) for k in placeholders_info)\n            # the above forces inputs to be tensors. 
(It's convenient to pass just lists here)\n            print(\".\",end=\"\")\n            sys.stdout.flush()\n            with tf.GradientTape() as tape:\n                results = model_fn[\"function\"](tf.constant(isTraining, shape=[]), **{**(model_static[\"paramsdict\"]), **cast_inputs})\n                loss = results[\"loss\"]\n                accur = results[\"accuracy\"]\n                if isTraining:\n                   grads = tape.gradient(loss, train_vars)\n                   optimizer.apply_gradients(zip(grads, train_vars))\n                n+=1\n                totalLoss += loss\n                totalAccur += accur\n        end_time = time()\n        totalAccur = totalAccur.numpy()\n        totalLoss = totalLoss.numpy()\n        if n > 0:\n            avgLoss = totalLoss / float(n)\n            avgAccur = totalAccur / float(n)\n            print(\".\")\n            print (\"Time=%.1f\" % (end_time - start_time), \"loss=%g\" % avgLoss, \"accuracy=%.3f\" % avgAccur)\n            return {\"loss\":avgLoss,\"accuracy\":avgAccur,\"time\":(end_time - start_time),\"error_rate\":1-avgAccur,\"start_time\":start_time}\n        else:\n            print (\"No data\")\n            return {\"loss\":0,\"accur\":0,\"time\":0,\"error_rate\":0,\"start_time\":0}\n\n    for e in range(epochs):\n        print (\"Epoch {0}/{1}\".format(e, epochs))\n        tr = halfEpoch(True)\n        va = halfEpoch(False)\n        epoch_stats = {\"train\":tr, \"val\":va, \"epoch\":e}\n        stats.append(epoch_stats)\n        if any(c(epoch_stats) for c in callbacks):\n            break\n    return stats\n\ndef StopWhenValidationGetsWorse(patience = 1):\n    '''Return a callback which stops training if validation loss gets worse.'''\n    bestLoss = 10000000000\n    p = patience\n    def callback(values):\n        nonlocal bestLoss, p, patience\n        newLoss = values[\"val\"][\"loss\"]\n        if newLoss > bestLoss:\n            p -= 1\n        else:\n            bestLoss = newLoss\n  
          p = patience\n        if p <= 0:\n            return True\n        return False\n    return callback\n\ndef StopWhenAccurate(phase=\"val\",error_rate = .01):\n    '''Return a callback which stops training if error rate drops below 1%'''\n    def callback(values):\n        nonlocal error_rate\n        return values[phase][\"error_rate\"] < error_rate\n    return callback\n\ndef Every(n,f):\n    '''Return a callback which calls its argument every n epochs'''\n    def callback(values):\n        nonlocal n,f\n        if values[\"epoch\"] % n == (n-1):\n            return f(values)\n        else:\n            return False\n    return callback\n\ndef Save(sess,saver,ckptfile):\n    def callback(values):\n        nonlocal sess,saver\n        print(\"Saving to\",ckptfile)\n        saver.save(sess, ckptfile)\n        return False\n    return callback\n\n################################################################################################\n# Prediction and evaluation\n\n\ndef evaluate (model_static, model_fn, xs, result=\"y_\"):\n    '''Evaluate the model for given input and result.\n    Input is given as a dictionary of lists to pass to session.run'''\n    phs = model_fn[\"placeholders\"]\n    if phs:\n        k0 = next (iter (phs.keys())) # 1st placeholder\n        total_len = len(xs[k0]) # total length\n    else:\n        total_len = 1\n    zeros = dict((k,tf.zeros(phs[k][\"shape\"][1:], # remove the batch size\n                             dtype=phs[k][\"dtype\"])) for k in phs.keys())\n    results = []\n    if model_fn[\"batched\"]:\n        def run():\n          bs = model_static[\"batch_size\"]\n          for i in range(0, bs*(-(-total_len//bs)), bs):\n              print(\".\",end=\"\")\n              chunks = dict()\n              for k in phs:\n                  chunks[k] = xs[k][i:i+bs]\n              if i + bs > total_len:\n                  # dealing with an incomplete last chunk\n                  origLen = total_len - i\n                  
for k in chunks:\n                      chunks[k] = list(chunks[k]) + [zeros[k]] * (bs - origLen)  # pad the last chunk\n              else:\n                  origLen = bs\n              chunks = {k: tf.cast(v,dtype=phs[k][\"dtype\"]) for (k,v) in chunks.items()}\n              results = model_fn[\"function\"](tf.constant(False, shape=[]), **{**(model_static[\"paramsdict\"]), **chunks}) \n              yield results[result][:origLen]\n        return np.concatenate(list(run()))\n    else:\n        def run():\n            for i in range(total_len):\n                inputs = {k: tf.cast(xs[k][i], dtype=phs[k][\"dtype\"]) for k in phs}\n                results = model_fn[\"function\"](tf.constant(False, shape=[]), **{**(model_static[\"paramsdict\"]), **inputs})\n                yield results[result]\n        return list(run())\n        \n\npredict = evaluate\n\ndef beam_translate(session, model, k, x, xlen, start_symbol, stop_symbol, debug=None):\n    '''Beam translation of ONE input sentence.'''\n    (_,out_len,voc_size) = model[\"y_\"].shape\n    xs = np.array ([x] * k) # The input is always the same\n    xs_len = np.array ([xlen]*k) # it is VERY important to get the length right\n    ys = [[start_symbol]]  # start with a single thing; otherwise the minimum will be repeated k times\n    probs = [1]\n    results = []\n    hist = [[]]\n    def pad(z):\n        return np.array(z + [0] * (out_len - len(z)))\n    for i in range(out_len-1):\n        print (\"beam search at:\", i)\n        inputs = {\"src_len\":xs_len[:len(ys)], \"src_in\":xs[:len(ys)], \"tgt_in\":np.array([pad(y) for y in ys])}\n        y_s = predict(session,model,inputs)\n        all_words = sorted([(y_s[j][i][w] * probs[j], ys[j] + [w], hist[j] + [y_s[j][i][w]])\n                            for j in range(len(y_s))\n                            for w in range(voc_size)])\n        best = all_words[-k:]\n        if debug is not None:\n            for x in best: debug(x)\n        results  += [(p,y,h) for 
(p,y,h) in best if y[i+1] == stop_symbol]\n        continued = [(p,y,h) for (p,y,h) in best if y[i+1] != stop_symbol]\n        if len(continued) == 0: break\n        (probs,ys,hist) = zip(*continued)\n    return sorted(results)\n\n######################################################\n# Saving and loading\n\ndef save(model_static, file):\n  numpy_tensors = {k:v.numpy() for (k,v) in model_static[\"paramsdict\"].items()}\n  print(\"Saving parameters: \", model_static[\"paramsdict\"].keys())\n  np.savez(file,**numpy_tensors)\n  print(\"Done\")\n\n\ndef load(model_static, file):\n  print(\"Loading parameters\")\n  numpy_tensors = np.load(file)\n  print(\"Loaded parameters: \", list(numpy_tensors.keys()))\n  for k,v in model_static[\"paramsdict\"].items():\n      v.assign(numpy_tensors[k])\n  print(\"Done\")\n"
  }
]