Repository: GU-CLASP/TypedFlow
Branch: master
Commit: 3a8fa230d413
Files: 39
Total size: 272.5 KB

Directory structure:
gitextract_rnbmpni2/

├── .gitignore
├── LICENSE
├── Makefile
├── README.org
├── TypedFlow/
│   ├── Abstract.hs
│   ├── Broadcast.hs
│   ├── Haskell.hs
│   ├── Layers/
│   │   ├── Core.hs
│   │   ├── RNN/
│   │   │   ├── Attention.hs
│   │   │   ├── Base.hs
│   │   │   └── Cells.hs
│   │   └── RNN.hs
│   ├── Layers.hs
│   ├── Learn.hs
│   ├── Memo.hs
│   ├── Memo2.hs
│   ├── Models/
│   │   ├── Topic.hs
│   │   └── Transformer.hs
│   ├── Python.hs
│   ├── TF.hs
│   ├── Types/
│   │   └── Proofs.hs
│   └── Types.hs
├── TypedFlow.hs
├── cabal.project
├── docs/
│   ├── HOT.org
│   └── Talk.org
├── examples/
│   ├── agreement/
│   │   └── Aggr.hs
│   ├── mnist/
│   │   ├── MNIST.hs
│   │   ├── Makefile
│   │   ├── main.py
│   │   └── mnist_model.py
│   └── seq2seq/
│       ├── GenTr.hs
│       ├── Makefile
│       ├── Seq2Seq.hs
│       ├── main.py
│       └── shell.nix
├── styx.yaml
├── typedflow.cabal
└── typedflow_rts.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.styx
*~
dist
dist-*
cabal-dev
*.o
*.hi
*.chi
*.chs.h
*.dyn_o
*.dyn_hi
.hpc
.hsenv
.cabal-sandbox/
cabal.sandbox.config
*.prof
*.aux
*.hp
*.eventlog
.stack-work/
cabal.project.local
.HTF/
/examples/seq2seq/s2s.py
/examples/seq2seq/synthtrees.txt
MNIST_data
__pycache__
/examples/seq2seq/GenTr
/.tramp_history


================================================
FILE: LICENSE
================================================
                   GNU LESSER GENERAL PUBLIC LICENSE
                       Version 3, 29 June 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.


  This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.

  0. Additional Definitions.

  As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.

  "The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.

  An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.

  A "Combined Work" is a work produced by combining or linking an
Application with the Library.  The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".

  The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.

  The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.

  1. Exception to Section 3 of the GNU GPL.

  You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.

  2. Conveying Modified Versions.

  If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:

   a) under this License, provided that you make a good faith effort to
   ensure that, in the event an Application does not supply the
   function or data, the facility still operates, and performs
   whatever part of its purpose remains meaningful, or

   b) under the GNU GPL, with none of the additional permissions of
   this License applicable to that copy.

  3. Object Code Incorporating Material from Library Header Files.

  The object code form of an Application may incorporate material from
a header file that is part of the Library.  You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:

   a) Give prominent notice with each copy of the object code that the
   Library is used in it and that the Library and its use are
   covered by this License.

   b) Accompany the object code with a copy of the GNU GPL and this license
   document.

  4. Combined Works.

  You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:

   a) Give prominent notice with each copy of the Combined Work that
   the Library is used in it and that the Library and its use are
   covered by this License.

   b) Accompany the Combined Work with a copy of the GNU GPL and this license
   document.

   c) For a Combined Work that displays copyright notices during
   execution, include the copyright notice for the Library among
   these notices, as well as a reference directing the user to the
   copies of the GNU GPL and this license document.

   d) Do one of the following:

       0) Convey the Minimal Corresponding Source under the terms of this
       License, and the Corresponding Application Code in a form
       suitable for, and under terms that permit, the user to
       recombine or relink the Application with a modified version of
       the Linked Version to produce a modified Combined Work, in the
       manner specified by section 6 of the GNU GPL for conveying
       Corresponding Source.

       1) Use a suitable shared library mechanism for linking with the
       Library.  A suitable mechanism is one that (a) uses at run time
       a copy of the Library already present on the user's computer
       system, and (b) will operate properly with a modified version
       of the Library that is interface-compatible with the Linked
       Version.

   e) Provide Installation Information, but only if you would otherwise
   be required to provide such information under section 6 of the
   GNU GPL, and only to the extent that such information is
   necessary to install and execute a modified version of the
   Combined Work produced by recombining or relinking the
   Application with a modified version of the Linked Version. (If
   you use option 4d0, the Installation Information must accompany
   the Minimal Corresponding Source and Corresponding Application
   Code. If you use option 4d1, you must provide the Installation
   Information in the manner specified by section 6 of the GNU GPL
   for conveying Corresponding Source.)

  5. Combined Libraries.

  You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:

   a) Accompany the combined library with a copy of the same work based
   on the Library, uncombined with any other library facilities,
   conveyed under the terms of this License.

   b) Give prominent notice with the combined library that part of it
   is a work based on the Library, and explaining where to find the
   accompanying uncombined form of the same work.

  6. Revised Versions of the GNU Lesser General Public License.

  The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time. Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.

  Each version is given a distinguishing version number. If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation. If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.

  If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.


================================================
FILE: Makefile
================================================

viewdoc: dist/doc/html/typedflow/index.html
	xdg-open $<

dist/doc/html/typedflow/index.html:
	styx cabal -- haddock --hyperlink-source
	styx cabal -- hscolour



================================================
FILE: README.org
================================================
#+TITLE: TypedFlow

TypedFlow is a typed, higher-order frontend to [[http://www.tensorflow.org][TensorFlow]] and a
high-level library for deep learning.

The main design principles are:

  - To make the parameters of layers explicit. This choice makes
    sharing of parameters explicit and allows "layers" to be
    implemented as pure functions.

  - To provide as precise as possible types. Functions are explicit
    about the shapes and elements of the tensors that they manipulate
    (they are often polymorphic in shapes and elements though.)

  - To let combinators be as transparent as possible. If an NN layer
    is a simple tensor transformation, it is exposed as such.


In this version, the interface to TensorFlow is provided via Python code
generation and a suitable runtime system.

** Documentation

The compiled documentation can be found on [[https://hackage.haskell.org/package/typedflow][hackage]].

** Examples

TypedFlow comes with two examples of neural networks:

 - An adaptation of the [[examples/mnist][MNIST tensorflow tutorial]]
 - A simple [[examples/seq2seq][sequence to sequence model]] which
   attempts to learn to translate pre-order into post-order.

The examples can be run as follows:

#+BEGIN_SRC shell
nix-env -iA nixpkgs.haskellPackages.styx
nix-env -iA nixpkgs.cabal2nix
styx configure
cd examples/seq2seq
make
#+END_SRC



================================================
FILE: TypedFlow/Abstract.hs
================================================
{-# LANGUAGE InstanceSigs #-}
{-|
Module      : TypedFlow.Abstract
Description : Abstract Tensor representations
Copyright   : (c) Jean-Philippe Bernardy, 2018
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental

This module provides operations on the abstract representation of
tensor operations. It is not normally imported directly by users.
-}

{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE CPP #-}
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE NoStarIsType #-}
#endif

module TypedFlow.Abstract where

import Control.Monad.RWS (RWS, tell, runRWS)
import Control.Monad.State
import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
import Prelude hiding (RealFrac(..))
import qualified TypedFlow.Memo as Memo0
import TypedFlow.Types
import TypedFlow.Broadcast
import TypedFlow.Types.Proofs

freeVarsT :: forall s t. KnownTyp t => KnownShape s
  => T s t -> [Int]
freeVarsT x = result
  where f :: forall s' t'. T s' t' -> [Int]
        f = Memo0.memo (protoFreevars f)
        result = f x

protoFreevars :: (forall s' t'. T s' t' -> [Int]) -> T s t -> [Int]
protoFreevars rec = \case
  BroadcastT _ _ _ _ x -> rec x
  MapT _ s f x -> rec x <> rec (f (T (Variable (Ref (-789) s typeSTyp))))
  Softmax _ _ x -> rec x
  DirectBroadcast _ _ _ _ x -> rec x
  GatherND _ _ _ x y -> rec x <> rec y
  Noise _ _ _ _ -> []
  Where cond x y -> rec cond <> rec x <> rec y
  If cond x y ->  rec cond <> rec x <> rec y
  T (Variable (Ref i _ _)) -> [i]
  T _ -> []
  Unbroadcast _p _u x -> rec x
  UnOp _op _ x -> rec x
  MatMul _ _ _ _ x y -> rec x <> rec y
  BinOp _op _ _ _ _ _ x y -> rec x <> rec y
  Gather _is _s0 _m _s1 x ix -> rec x <> rec ix
  Transpose _ _t x -> rec x
  ReshapeFrom _s x -> rec x
  Concat _s0  _s1 xs -> mconcat $ htoList $ hmap (\(Catable _ x) -> K (rec x)) xs
  Convolution _bs _inChans _outChans _filterShape _s x filters -> rec x <> rec filters
  Pool _ _ _ _ _ x  -> rec x
  _ -> error "protoFreevars: unhandled case"
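{- Illustration, not part of the TypedFlow source. protoFreevars is written
in open-recursion style: it never calls itself, it calls the function
'rec' it is given. freeVarsT then ties the knot by passing in a memoised
version of itself (via Memo0.memo), so recursive calls go through the
cache. The sketch below reproduces only that open-recursion shape on a
hypothetical, untyped Expr type. The names Expr, protoFreeVars, freeVars
and uniqueFreeVars are illustrative assumptions, and the knot is tied with
plain recursion rather than memoisation.

```haskell
import Data.List (nub)

-- Toy stand-in for TypedFlow's tensor AST (hypothetical; the real 'T'
-- is a GADT indexed by shape and element type).
data Expr
  = Var Int        -- reference to variable number i
  | Lit Double     -- constant
  | Add Expr Expr
  | Neg Expr
  deriving Show

-- Open recursion: sub-terms are visited through 'rec', not by calling
-- protoFreeVars directly, so the caller chooses how to tie the knot.
protoFreeVars :: (Expr -> [Int]) -> Expr -> [Int]
protoFreeVars rec e = case e of
  Var i   -> [i]
  Lit _   -> []
  Add x y -> rec x <> rec y
  Neg x   -> rec x

-- Tie the knot with plain recursion (TypedFlow passes a memoised
-- function here instead).
freeVars :: Expr -> [Int]
freeVars = protoFreeVars freeVars

-- Each variable once, in order of first occurrence.
uniqueFreeVars :: Expr -> [Int]
uniqueFreeVars = nub . freeVars
```

Because the traversal is a parameter, swapping plain recursion for a
memoising wrapper requires no change to protoFreeVars itself.
-}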




genTrainingPlaceholder :: Scalar TFBool
genTrainingPlaceholder = T (ExternalVar (Ref "training_placeholder" typeSShape typeSTyp))


-- | Zeros
zeros :: ∀ t (shape :: Shape). KnownNumeric t => KnownShape shape => (T shape t)
zeros = constant $ knownNum @t $ 0

defaultT :: ∀ t (shape :: Shape). KnownShape shape => KnownTyp t => (T shape t)
defaultT = case typeSTyp @t of
                 STyp SFloat _ _ -> zeros
                 STyp SInt _ _ -> zeros
                 STyp SBool _ _ -> constant False
                 _ -> error "defaultT: unhandled case"


-- | Ones
ones :: ∀ t (shape :: Shape). KnownShape shape => KnownNumeric t => (T shape t)
ones = knownNum @t $ constant 1

-- | Identity matrix in dimensions n,n
eye :: ∀ n t. KnownNat n => KnownNumeric t => (T '[n,n] t)
eye = diag 1

diag :: ∀ n t. KnownTyp t => KnownNat n => T '[n] t ->  T '[n,n] t
diag = UnOp (Diag Sat) Unit

expm :: ∀ n t. KnownNat n => KnownNumeric t => T '[n,n] t ->  T '[n,n] t
expm = UnOp (ExpM Sat) Unit

-- | @k@=diagonal above which to zero elements. k = 0 is the main diagonal, k < 0 is below it and k > 0 is above.
tril :: ∀ n t. KnownNat n => KnownNumeric t => Integer -> T '[n,n] t ->  T '[n,n] t
tril k = UnOp (ZeroTriangle Sat Lower k) Unit

triu :: ∀ n t. KnownNat n => KnownNumeric t => Integer -> T '[n,n] t ->  T '[n,n] t
triu k = UnOp (ZeroTriangle Sat Upper k) Unit
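{- Illustration, not part of the TypedFlow source. The comment on tril says
that k selects the diagonal above which elements are zeroed: k = 0 is the
main diagonal, k < 0 is below it and k > 0 is above it. In index terms,
element (i, j) is kept when j - i <= k. For triu, the sketch assumes the
mirror-image rule (keep when j - i >= k), which is the NumPy/TensorFlow
convention. The list-based helpers trilL, triuL and keepWhere are
hypothetical.

```haskell
-- Zero every element strictly above diagonal k: keep (i, j) when j - i <= k.
trilL :: Num a => Integer -> [[a]] -> [[a]]
trilL k = keepWhere (\i j -> j - i <= k)

-- Assumed mirror image for triu: keep (i, j) when j - i >= k.
triuL :: Num a => Integer -> [[a]] -> [[a]]
triuL k = keepWhere (\i j -> j - i >= k)

-- Replace elements whose (row, column) index fails the predicate with 0.
keepWhere :: Num a => (Integer -> Integer -> Bool) -> [[a]] -> [[a]]
keepWhere keep m =
  [ [ if keep i j then x else 0 | (j, x) <- zip [0 ..] row ]
  | (i, row) <- zip [0 ..] m ]
```

For example, trilL (-1) keeps only the strictly lower triangle.
-}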


-- | Constant
constant :: forall s t w. KnownShape s => KnownBits w => KnownKind t => HaskType ('Typ t w) -> T s ('Typ t w)
constant c = appRUnit @s #> broadcastTT @s (scalar c)

scalar :: forall t w. KnownBits w => KnownKind t => HaskType ('Typ t w) -> Scalar ('Typ t w)
scalar = T . Constant

reduceAll :: forall s t. KnownTyp t => KnownShape s =>
     (∀n s'. (KnownTyp t,KnownShape s') => Axis n s' -> T s' t -> T (Take n s' ++ Drop ('Succ n) s') t) -> Tensor s t -> Tensor '[] t
reduceAll op x = knownProduct @s ?>
   op axis0 (reshapeTo ((:*) (productS (typeSShape @s)) Unit) x)

-- | Reduce all dimensions of the input tensor to a scalar (mean, sum, maximum or minimum).
reduceMeanAll, reduceSumAll, reduceMaxAll, reduceMinAll :: ∀ (s :: Shape) t. KnownNumeric t => KnownShape s => Tensor s t -> Tensor '[] t
reduceMaxAll = reduceAll reduceMax
reduceMeanAll = reduceAll reduceMean
reduceSumAll = reduceAll reduceSum
reduceMinAll = reduceAll reduceMin

sShapeTake' :: Axis n s -> SList' f s -> SList' f (Take n s)
sShapeTake' AxZero _s = Unit
sShapeTake' (AxSucc n) ((:*) x xs) = (:*) x (sShapeTake' n xs)

sShapeDrop' :: Axis n s -> SList' f s -> SList' f (Drop n s)
sShapeDrop' AxZero s = s
sShapeDrop' (AxSucc n) ((:*) _ xs) = sShapeDrop' n xs

sShapeDropSucc :: Axis n s -> SList' f s -> SList' f (Drop ('Succ n) s)
sShapeDropSucc AxZero (_ :* s) = s
sShapeDropSucc (AxSucc n) (_ :* xs) = sShapeDropSucc n xs

-- | Internal. Use 'reduceSum', etc. instead.
reduce :: ∀ n s t. KnownNumeric t => (KnownShape s) => ReduceOp -> Axis n s -> T s t -> T (Take n s ++ Drop ('Succ n) s) t
reduce op n x = case axisSplitApp' n of
  Refl -> UnOp (Axis1Op (sShapeDropSucc n s) (ReduceOp (hlookup n s) op)) (sShapeTake' n s) x
 where s = typeSShape @s

-- | Reduce along a given dimension
reduceSum, reduceMean, reduceMax, reduceMin :: ∀n s t. (KnownNumeric t,KnownShape s) => Axis n s -> T s t -> T (Take n s ++ Drop ('Succ n) s) t
reduceSum = reduce Sum
reduceMean = reduce Mean
reduceMax = reduce Max
reduceMin = reduce Min
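{- Illustration, not part of the TypedFlow source. Reducing along axis n
removes that dimension. This is what the result type
Take n s ++ Drop ('Succ n) s expresses. reduceAll first reshapes the
tensor to one flat dimension and then reduces axis 0. The sketch mirrors
both ideas on lists, with shapes as plain [Int] values. The helpers
reducedShape, reduce2, meanAxis and reduceAllL are hypothetical and only
handle rank-2 tensors.

```haskell
-- Shape after reducing axis n: dimension n disappears.
reducedShape :: Int -> [Int] -> [Int]
reducedShape n s = take n s ++ drop (n + 1) s

-- Reduce a rank-2 tensor (a list of rows) along axis 0 or axis 1.
reduce2 :: (a -> a -> a) -> Int -> [[a]] -> [a]
reduce2 op 0 rows = foldr1 (zipWith op) rows  -- one result per column
reduce2 op 1 rows = map (foldr1 op) rows      -- one result per row
reduce2 _  n _    = error ("reduce2: axis out of range: " ++ show n)

-- Mean = sum divided by the size of the reduced axis.
meanAxis :: Int -> [[Double]] -> [Double]
meanAxis ax rows = map (/ size) (reduce2 (+) ax rows)
  where
    size = fromIntegral (reducedLen ax)
    reducedLen 0 = length rows
    reducedLen _ = length (head rows)

-- reduceAll-style: flatten everything, then reduce the single axis.
reduceAllL :: (a -> a -> a) -> [[a]] -> a
reduceAllL op = foldr1 op . concat
```
-}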


-- | Sum along the first dimension
reduceSum0 :: ∀ s' n t. KnownNat n => KnownNumeric t => KnownShape s' => Tensor (n ': s') t -> Tensor s' t
reduceSum0 = reduceSum axis0



addN :: ∀ s t. KnownNumeric t => KnownShape s => [Tensor s t] -> Tensor s t
addN [] = zeros
addN ts = foldr1 (+) ts

instance (KnownNumeric t, KnownShape s) => Num (T s t) where
  (+) = (⊕)
  (*) = (⊙)
  signum = unOp Sign
  fromInteger x = knownNum @t $ constant (fromIntegral x)
  abs = unOp Abs
  (-) = (⊝)
  negate = unOp Negate

instance (KnownFloat b, KnownShape s) => Fractional (T s b) where
  fromRational x = knownAlgebraic @b $ constant (fromRational x :: HaskType b)
  (/) = (⊘)

instance (KnownFloat b, KnownShape s) => Floating (T s b) where
  pi = knownAlgebraic @b $ constant pi
  exp = unFlOp Exp
  log = unFlOp Log
  sin = unFlOp Sin
  cos = unFlOp Cos
  asin = unFlOp Asin
  acos = unFlOp Acos
  sinh = unFlOp Sinh
  cosh = unFlOp Cosh
  asinh = unFlOp Asinh
  acosh = unFlOp Acosh
  tanh = unFlOp Tanh
  atan = unFlOp Atan
  atanh = unFlOp Atanh
  sqrt = unFlOp Sqrt
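{- Illustration, not part of the TypedFlow source. The Num instance for
T s t requires KnownShape s. One visible reason is fromInteger: it
receives no tensor to copy a shape from, so it must build a constant of
the right shape from type information alone (here via 'constant'). The
sketch shows the same constraint on a hypothetical length-indexed list
type Vec, where KnownNat n plays the role of KnownShape s.

```haskell
{-# LANGUAGE DataKinds, KindSignatures, ScopedTypeVariables #-}
import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

-- Hypothetical length-indexed vector: the length n lives in the type.
newtype Vec (n :: Nat) a = Vec [a] deriving (Eq, Show)

-- Pointwise arithmetic. 'fromInteger' has no vector argument to copy a
-- length from, so it reads n from the type via 'KnownNat n'.
instance (KnownNat n, Num a) => Num (Vec n a) where
  Vec xs + Vec ys = Vec (zipWith (+) xs ys)
  Vec xs - Vec ys = Vec (zipWith (-) xs ys)
  Vec xs * Vec ys = Vec (zipWith (*) xs ys)
  negate (Vec xs) = Vec (map negate xs)
  abs (Vec xs)    = Vec (map abs xs)
  signum (Vec xs) = Vec (map signum xs)
  fromInteger c   = Vec (replicate len (fromInteger c))
    where len = fromIntegral (natVal (Proxy :: Proxy n))
```

With this instance, a literal such as 10 in Vec [1,2,3] + 10 is expanded
to Vec [10,10,10], much as a numeric literal is broadcast to a constant
tensor.
-}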

-- | Pretend that the argument is a constant for the purposes of
-- gradient computation
stopGradient :: ∀ s t. KnownTyp t => KnownShape s => Tensor s t -> Tensor s t
stopGradient = appRUnit @s #> UnOp StopGradient (typeSShape @s)

-- | Divide tensors, broadcasting along shape @s@
(⊘) :: forall s t. KnownAlgebraic t => KnownShape s => T s t -> T s t -> T s t
(⊘) = binOp Divide

-- | Integer floor division of tensors, broadcasting along shape @s@
floorDiv :: forall s w. KnownBits w => KnownShape s => T s ('Typ 'Int w) -> T s ('Typ 'Int w) -> T s ('Typ 'Int w)
floorDiv = binOp IntegerDiv


-- | Indexwise equality test.
equal :: forall s t. (KnownShape s, KnownTyp t) => Tensor s t -> Tensor s t -> Tensor s TFBool
equal = binOp (Equal)

-- | Indexwise operators (addition, subtraction, multiplication)
(⊕), (⊝), (⊙)  :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s t
(⊝) = binOp Subtract
(⊙) = binOp Multiply
(⊕) = binOp Add

maxT,minT :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s t
maxT = binOp Maximum
minT = binOp Minimum

mkComplex :: KnownBits w => KnownShape s => Tensor s (Flt w) -> Tensor s (Flt w) -> Tensor s ('Typ 'Cmplx w)
mkComplex = binOp MkComplex

lessThan :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s TFBool
lessThan = binOp (Comparision Less)

lessOrEqualThan :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s TFBool
lessOrEqualThan = binOp (Comparision LessOrEqual)

greaterThan :: ∀ (s :: Shape) t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s TFBool
greaterThan = binOp (Comparision Greater)

logicAnd :: ∀ (s :: Shape). (KnownShape s) => Tensor s TFBool -> Tensor s TFBool-> Tensor s TFBool
logicAnd = binOp (Logic And)


infixl 7 ⊙,⊘
infixl 6 ⊕,⊝


-- | Matrix multiplication of shapes @[n,o]@ and @[o,m]@, giving shape @[n,m]@
matmul :: forall m n o t. KnownNumeric t => KnownNat m => KnownNat o => KnownNat n => KnownTyp t => T '[n,o] t -> T '[o,m] t -> T '[n,m] t
matmul = MatMul Unit Sat Sat Sat
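{- Illustration, not part of the TypedFlow source. The type of matmul,
T '[n,o] t -> T '[o,m] t -> T '[n,m] t, forces the inner dimension o to
agree. The sketch computes the same product on lists of rows. The
hypothetical helper matmulL checks nothing about shapes; the types do
that in TypedFlow.

```haskell
import Data.List (transpose)

-- Multiply an n-by-o matrix by an o-by-m matrix (lists of rows), giving n-by-m.
-- Entry (i, j) is the dot product of row i of a with column j of b.
matmulL :: Num a => [[a]] -> [[a]] -> [[a]]
matmulL a b = [ [ sum (zipWith (*) row col) | col <- cols ] | row <- a ]
  where cols = transpose b
```
-}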

unOp :: forall s t. KnownShape s => KnownNumeric t => Num1Op -> T s t -> T s t
unOp op = appRUnit @s #> UnOp (Num1Op op)  (typeSShape @s)

unFlOp :: forall s t. KnownBits t => KnownShape s => Float1Op -> T s (Flt t) -> T s (Flt t)
unFlOp op = appRUnit @s #> UnOp (Float1Op op) (typeSShape @s)

binOp :: forall s t u. KnownShape s => KnownTyp t => Simple2Op t u -> T s t -> T s t -> T s u
binOp op = appRUnit @s #> BinOp (Simple2Op op) (typeSShape @s) Unit typeSTyp Unit typeSTyp

conjugate :: ∀ s w. KnownShape s => KnownBits w => T s ('Typ 'Cmplx w) ->  T s ('Typ 'Cmplx w)
conjugate = appRUnit @s #> UnOp Conjugate (typeSShape @s)

realPart :: ∀ s w. KnownShape s => KnownBits w => T s ('Typ 'Cmplx w) ->  T s ('Typ 'Float w)
realPart = appRUnit @s #> UnOp RealPart (typeSShape @s)

sigmoid, relu, square, round, floor, hardSigmoid
   :: ∀ s t. (KnownShape s, KnownBits t)
   => Tensor s ('Typ 'Float t) -> Tensor s ('Typ 'Float t)
sigmoid = unFlOp Sigmoid
hardSigmoid = unFlOp HardSigmoid
square = unOp Square
relu = unFlOp Relu

floorMod :: ∀ s t. (KnownShape s, KnownNumeric t) => Tensor s t -> Tensor s t -> Tensor s t
floorMod = binOp FloorMod

-- Unfortunately RealFrac is utterly broken; so we have to do this:
round = unFlOp Round
floor = unFlOp Floor
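{- Illustration, not part of the TypedFlow source. floorDiv and floorMod are
named after floor semantics. The sketch assumes they behave like
TensorFlow's floordiv/floormod (and Python's // and %), which round toward
negative infinity. On integers, Haskell's div and mod have exactly that
behaviour, whereas quot and rem truncate toward zero. The list helpers
floorDivL, floorModL and truncDivL are hypothetical.

```haskell
-- Pointwise floor division and modulus on equal-length lists.
-- 'div'/'mod' round toward negative infinity (floor semantics);
-- 'quot' truncates toward zero.
floorDivL, floorModL, truncDivL :: [Integer] -> [Integer] -> [Integer]
floorDivL = zipWith div
floorModL = zipWith mod
truncDivL = zipWith quot
```

The two conventions only differ when the operands have opposite signs.
For example, -7 divided by 2 is -4 under floor division and -3 under
truncation.
-}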

-- | Take a slice at dimension n from index i (inclusive) to j (exclusive).
slice :: forall i j s t n. KnownTyp t => KnownShape s => KnownNat j => KnownNat i => (i <= j, j <= At n s, KnownLen s) =>
         Axis n s -> Tensor s t -> Tensor (Take n s ++ ((j-i) ': Drop ('Succ n) s)) t
slice n = case axisSplitApp' n of
  Refl -> UnOp (Axis1Op (sShapeDropSucc n s) (SliceOp (Proxy @(j-i)) (hlookup n s) (natVal (Proxy @i)) (natVal (Proxy @j))))
               (sShapeTake' n s)
 where s = typeSShape @s


slice1 :: forall i j m n s t. KnownShape s => KnownNat m => KnownNat n => KnownTyp t => KnownNat j => KnownNat i => (i <= j, j <= m, KnownLen s) =>
         Tensor (n ': m ': s) t -> Tensor (n ': (j-i) ': s) t
slice1 = slice @i @j axis1

slice0 :: forall i j m s t. KnownShape s => KnownNat m => KnownTyp t => KnownNat j => KnownNat i => (i <= j, j <= m, KnownLen s) =>
         Tensor (m ': s) t -> Tensor ((j-i) ': s) t
slice0 = slice @i @j axis0
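{- Illustration, not part of the TypedFlow source. slice replaces the size
of axis n with j - i, so the slice covers indices i up to but excluding j.
The sketch shows that behaviour on a rank-2 tensor stored as a list of
rows. sliceL is a hypothetical helper that handles only axes 0 and 1.

```haskell
-- Slice [i, j) along axis 0 (rows) or axis 1 (columns) of a rank-2 tensor.
sliceL :: Int -> Int -> Int -> [[a]] -> [[a]]
sliceL 0 i j rows = take (j - i) (drop i rows)
sliceL 1 i j rows = map (take (j - i) . drop i) rows
sliceL n _ _ _    = error ("sliceL: axis out of range: " ++ show n)
```
-}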



-- | Concatenate tensors on dimension @n@. Recommended: use @zipWithTT (concat0 ...)@ instead.
concatT :: ∀ n d1 d2 s t. KnownNat d2 => KnownNat d1 => KnownShape s => (KnownTyp t, (d1+d2) ~ At n s) =>
    Axis n s -> T (Take n s ++ (d1 ': Drop ('Succ n) s)) t -> T (Take n s ++ (d2 ': Drop ('Succ n) s)) t -> T s t
concatT n = case axisSplitApp' n of Refl -> concatT' (sShapeTake' n s) d1 d2 (sShapeDropSucc n s)
  where s = typeSShape @s; d1 = natSat @d1; d2 = natSat @d2

-- | Concatenate tensors on the first dimension
concat0, (©) :: ∀ d1 d2 ys t. KnownTyp t => KnownShape ys => KnownNat d2 => KnownNat d1 => (KnownLen ys) => T (d1 ': ys) t -> T (d2 ': ys) t -> T ((d1 + d2) ': ys) t
concat0 = concatT axis0

(©) = concat0

-- | Concatenate tensors on the second dimension
concat1 :: ∀ n ys d1 d2 t. KnownShape ys => KnownNat n => KnownNat d2 => KnownNat d1 => KnownTyp t => (KnownLen ys) =>  T (n ': d1 ': ys) t -> T (n ': d2 ': ys) t -> T (n ': (d1 + d2) ': ys) t
concat1 = concatT axis1
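{- Illustration, not part of the TypedFlow source. Concatenating along an
axis adds the two sizes on that axis (d1 + d2) and leaves the other
dimensions unchanged, as in the types of concat0 and concat1. The sketch
shows the same behaviour for rank-2 lists. concatL is a hypothetical
helper.

```haskell
-- Concatenate two rank-2 tensors along axis 0 (more rows) or axis 1
-- (each row grows by the other's row length).
concatL :: Int -> [[a]] -> [[a]] -> [[a]]
concatL 0 xs ys = xs ++ ys
concatL 1 xs ys = zipWith (++) xs ys
concatL n _ _   = error ("concatL: axis out of range: " ++ show n)
```
-}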

-- | Add an extra dimension at axis (@n@) of size 1.
expandDim :: forall n s t. KnownTyp t => KnownShape s => (PeanoNat n <= Length s) => Tensor s t -> Tensor (Take n s ++ (1 ': Drop n s)) t
expandDim x =
  -- Product (Take n s ++ (1 ': Drop n s))
  prodHomo @(Take n s) @(1 ': Drop n s) #>
  -- Product (Take n s) * Product (Drop n s)
  prodHomo @(Take n s) @(Drop n s) #>
  -- Product (Take n s ++ Drop n s)
  takeDrop @s @n #>
  -- Product s
  reshapeFrom (typeSShape @s) x

-- +expandDim :: forall n s t. KnownTyp t => KnownShape s => Axis n s -> Tensor s t -> Tensor (Take n s ++ (1 ': Drop n s)) t
-- +expandDim ax x = case expandDimProof ax s of Refl -> reshapeFrom s x

-- | Add an extra dimension at axis (0) of size 1.
expandDim0 :: ∀ s t. KnownShape s => KnownTyp t => KnownLen s => Tensor s t -> Tensor (1 ': s) t
expandDim0 = reshape

-- | Add an extra dimension at axis (1) of size 1.
expandDim1 :: ∀ n s t. KnownNat n => KnownTyp t => KnownShape s => Tensor (n ': s) t -> Tensor (n ': 1 ': s) t
expandDim1 = reshape
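{- Illustration, not part of the TypedFlow source. expandDim inserts a
size-1 axis at position n, giving the shape Take n s ++ (1 ': Drop n s).
Inserting a 1 never changes the number of elements, so the operation can
be a plain reshape. The proof terms in expandDim establish that equality
of products at the type level. The sketch checks the same facts on
value-level shapes. expandDimShape and sameSize are hypothetical helpers.

```haskell
-- Shape after inserting a size-1 axis at position n.
expandDimShape :: Int -> [Int] -> [Int]
expandDimShape n s = take n s ++ [1] ++ drop n s

-- The element count (product of dimensions) is unchanged.
sameSize :: Int -> [Int] -> Bool
sameSize n s = product (expandDimShape n s) == product s
```
-}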


-- | Flatten all the dimensions of the tensor
flattenAll :: forall s t. KnownTyp t => KnownShape s => Tensor s t -> Tensor '[Product s] t
flattenAll = knownProduct @s ?> reshape

inflateAll :: forall s t. KnownTyp t => KnownShape s => Tensor '[Product s] t -> Tensor s t
inflateAll = knownProduct @s ?> reshape


squeeze0 :: ∀ s t. KnownTyp t => (KnownShape s) => Tensor (1 ': s) t -> Tensor s t
squeeze0 = reshape

atShape :: SList s -> T s t -> T s t
atShape _ x = x

-- | Reshape a tensor so that the last two dimensions are collapsed
flattenN2 :: ∀ s m n t. KnownTyp t => (KnownNat m, KnownNat n, KnownShape s) => Tensor (s ++ '[m,n]) t -> Tensor (s ++ '[m*n]) t
flattenN2  = prodHomo @s @'[m,n] #>
             prodHomo @s @'[m*n] #>
             knownAppend @s @'[m*n] ?>
             knownAppend @s @'[m,n] ?>
             reshape

-- | Reshape a tensor so that the first three dimensions are collapsed
flatten3 :: ∀ m n o s t. KnownTyp t => (KnownNat m, KnownNat n, KnownNat o, KnownShape s) => Tensor (m ': n ': o ': s) t -> Tensor (m*n*o ': s) t
flatten3  =  -- (m * (n * (o * Product s)))
             prodAssoc @m @n @(o * Product s) #>
             -- (m * n) * (o * Product s)
             prodAssoc @(m * n) @o @(Product s) #>
             -- ((m * n) * o) * Product s
             reshape

-- | Reshape a tensor so that the second and third dimensions are collapsed
flatten12 :: ∀ m n o s t. KnownTyp t => KnownNat o => (KnownNat m, KnownNat n, KnownShape s) => Tensor (o ': m ': n ': s) t -> Tensor (o ': m*n ': s) t
flatten12 = prodAssoc @m @n @(Product s) #> reshape

-- | Reshape a tensor so that the first dimension is expanded into three.
inflate3 :: ∀ m n o s t. KnownTyp t => (KnownNat m, KnownNat n, KnownNat o, KnownShape s) => Tensor (m*n*o ': s) t -> Tensor (m ': n ': o ': s) t
inflate3 = -- (m * (n * (o * Product s)))
           prodAssoc @m @n @(o * Product s) #>
           -- (m * n) * (o * Product s)
           prodAssoc @(m * n) @o @(Product s) #>
           -- ((m * n) * o) * Product s
           reshape

-- | Reshape a tensor so that the second dimension is expanded into two.
inflate12 :: ∀ m n o s t. KnownTyp t => KnownNat o => (KnownNat m, KnownNat n, KnownShape s) => Tensor (o ': m*n ': s) t -> Tensor (o ': m ': n ': s) t
inflate12 = prodAssoc @m @n @(Product s) #> reshape
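{- Illustration, not part of the TypedFlow source. The flatten and inflate
functions above are reshapes. The sketch assumes TensorFlow's row-major
(C-order) layout, where element (i, j) of an m-by-n block lands at flat
index i * n + j. Under that layout, collapsing two dimensions is concat,
and expanding them again is chunking into rows of length n. chunksOf,
flatten2 and inflate2 are hypothetical helpers.

```haskell
-- Split a list into consecutive chunks of length k (k > 0).
chunksOf :: Int -> [a] -> [[a]]
chunksOf _ [] = []
chunksOf k xs = let (h, t) = splitAt k xs in h : chunksOf k t

-- Collapse an m-by-n block into m*n elements (row-major).
flatten2 :: [[a]] -> [a]
flatten2 = concat

-- Expand m*n elements back into m rows of length n.
inflate2 :: Int -> [a] -> [[a]]
inflate2 = chunksOf
```

inflate2 n is a left inverse of flatten2 whenever all rows have length n.
This is the list analogue of inflate3 undoing flatten3.
-}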


-- | Access the last element in a tensor (in the 0th dimension)
last0 :: ∀ n s t. KnownShape s => KnownTyp t => KnownNat n => KnownLen s => T (n ': s) t -> Tensor s t
last0 = nth0 (natVal (Proxy @n) - 1)

-- | Access the nth element in a tensor (in the 0th dimension)
nth0 :: ∀ n s t. KnownTyp t => KnownNat n => KnownShape s => Integer -> T (n ': s) t -> Tensor s t
nth0 i x = UnOp (Axis1Op (typeSShape @s) (AccessOp (natSat @n) i)) Unit x

-- | Access the nth element in a tensor (in the 0th dimension), with a static index
nth0' :: ∀ n m s t. KnownNat m => KnownTyp t => KnownShape s => KnownNat n => KnownLen s => n < m => T (m ': s) t -> Tensor s t
nth0' = nth0 (natVal (Proxy @n))

vecToNP :: forall a f n k. (a -> f 1) -> V n a -> (forall xs. Sum xs ~ n => NP f xs -> k) -> k
vecToNP _f VUnit k = k Unit
vecToNP f (x :** xs) k = vecToNP f xs $ \xs' -> k (f x :* xs')

stackT :: ∀ s0 s (n::Nat) t. KnownShape s => KnownShape s0 => KnownNat n => (KnownLen s0) => V n (T (s0 ++ s) t) -> Tensor (s0 ++ (n ': s)) t
stackT v = vecToNP @(T (s0++s) t) @(Catable s0 s t)
             (\x -> (Catable (natSat @1) $ (prodHomoS s0 s #>
                                            prodHomoS s0 (natSat @1 :* s) #>
                                            knownAppend @s0 @s ?>
                                            knownSShape (s0 .+. natSat @1 :* s) ?>
                                            reshape x)))
             v $ (Concat (typeSShape @s0)  (typeSShape @s)) 
  where s = typeSShape @s; s0 = typeSShape @s0


-- | Stack @n@ tensors along a new first dimension
stack0 :: ∀ s (n::Nat) t. KnownNat n => KnownShape s => (KnownLen s) => V n (T s t) -> Tensor (n ': s) t
stack0 = stackT @'[]

-- | Stack @n@ tensors along a new second dimension
stack1 :: ∀ s (n::Nat) m t. KnownNat n => KnownNat m => KnownShape s => (KnownLen s) => V n (T (m ': s) t) -> Tensor (m ': n ': s) t
stack1 = stackT @'[m]

-- | Stack @n@ tensors along a new last dimension
stackN :: ∀ s (n::Nat) t. KnownNat n => KnownShape s => V n (T s t) -> Tensor (s ++ '[n]) t
stackN = appRUnit @s #>
         stackT @s @'[]


-- | Split a tensor into @n@ tensors along the first dimension
unstack0 :: ∀ s (n::Nat) t. KnownTyp t => KnownNat n => KnownShape s => (KnownLen s) => Tensor (n ': s) t -> V n (T s t)
unstack0 x = fmap (`nth0` x) (vcount @n)
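{- Illustration, not part of the TypedFlow source. stackT works in two
steps. First, it reshapes each input to carry an extra size-1 axis (the
reshape inside Catable). Second, it concatenates the pieces along that
axis. The sketch replays the two steps on lists for the stack1 case with
s = '[]: n vectors of length m become an m-by-n tensor. Unstacking along
the same axis undoes it. expandLast, stack1L and unstack1L are
hypothetical helpers, and stack1L assumes a non-empty input list.

```haskell
import Data.List (transpose)

-- Step 1: give a length-m vector a trailing size-1 axis, [m] -> [m,1].
expandLast :: [a] -> [[a]]
expandLast = map (: [])

-- Step 2: concatenate the [m,1] pieces along axis 1, giving [m,n].
stack1L :: [[a]] -> [[a]]
stack1L vs = foldr1 (zipWith (++)) (map expandLast vs)

-- Unstacking along axis 1 recovers the n original vectors.
unstack1L :: [[a]] -> [[a]]
unstack1L = transpose
```
-}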

-- | Stack a tensor vector. (To be used on literal lists of tensors.)
litStack0 :: KnownShape s => KnownLen xs => TV s t xs -> Tensor (Length xs ': s) t
litStack0 tv = knownSList tv ?> stack0 $ toV tv
  where toV :: TV s t xs -> V (Length xs) (T s t)
        toV Unit = VUnit
        toV (K x :* xs) = x :** toV xs

-- | Generate a mask of given length for each sequence.
sequenceMask :: forall maxlen. KnownNat maxlen => Tensor '[] Int32 -> Tensor '[maxlen] TFBool
sequenceMask lens = mapT (`lessThan` lens) (range @maxlen)
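{- Illustration, not part of the TypedFlow source. A sequence mask of width
maxlen for a sequence of length len is True exactly at positions
i < len, the convention of TensorFlow's sequence_mask. Note the direction
of the comparison: the position is compared against the length, not the
other way round. sequenceMaskL is a hypothetical list-based helper.

```haskell
-- Positions 0 .. maxlen-1 that fall inside a sequence of length len.
sequenceMaskL :: Int -> Int -> [Bool]
sequenceMaskL maxlen len = [ i < len | i <- [0 .. maxlen - 1] ]
```
-}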

-- | simple broadcasting of a tensor (like a zero-arity map)
broadcastT :: forall n s t. KnownShape s => KnownNat n => KnownTyp t => KnownLen s => T s t ->  T (n ': s) t
broadcastT x = BroadcastT Nothing False (natSat @n) typeSShape x

-- | simple broadcasting of a tensor
broadcastTT :: forall a s t. KnownShape s => KnownTyp t => KnownShape a => KnownLen s => T s t ->  T (a ++ s) t
broadcastTT x = prodHomo @a @s #>
                knownProduct @a ?>
                knownAppend @a @s ?>
                reshape (broadcastT @(Product a) x)
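{- Illustration, not part of the TypedFlow source. broadcastTT does not
broadcast once per leading axis. Instead, it broadcasts once along a
single axis of size Product a, then reshapes that axis into the shape a.
The sketch models a tensor as a hypothetical record of a shape plus its
elements in row-major order. Tensor, broadcastT, reshape and broadcastTT
here are illustrative stand-ins, not TypedFlow's definitions.

```haskell
-- Hypothetical flat tensor: shape plus row-major elements.
data Tensor a = Tensor { shape :: [Int], elems :: [a] } deriving (Eq, Show)

-- Add a leading axis of size n by repeating the data n times.
broadcastT :: Int -> Tensor a -> Tensor a
broadcastT n (Tensor s xs) = Tensor (n : s) (concat (replicate n xs))

-- Same elements, new shape; the element counts must agree.
reshape :: [Int] -> Tensor a -> Tensor a
reshape s' (Tensor s xs)
  | product s' == product s = Tensor s' xs
  | otherwise               = error "reshape: element counts differ"

-- Broadcast once by Product a, then split that axis into the shape a.
broadcastTT :: [Int] -> Tensor a -> Tensor a
broadcastTT a t = reshape (a ++ shape t) (broadcastT (product a) t)
```

With a = [] the product is 1, so broadcastTT [] leaves the tensor
unchanged.
-}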



-- | Map a function along the first dimension of a tensor
mapT :: forall n s r t u. KnownShape s => KnownNat n => KnownTyp t => KnownLen r => KnownLen s
     => (T s t -> T r u) ->  T (n ': s) t -> T (n ': r) u
mapT f x = MapT Sat typeSShape f x

-- | Map a function along the first few dimensions of a tensor, given by the first type parameter
mapTT :: forall a s t r u. KnownShape r => KnownShape a => KnownTyp u => KnownLen r => KnownShape s => KnownTyp t
  => (T s t -> T r u) ->  T (a ++ s) t -> T (a ++ r) u
mapTT f x = prodHomo @a @r #>
            prodHomo @a @s #>
            knownProduct @a ?>
            knownAppend @a @r ?>
            knownAppend @a @s ?>
            reshape (mapT @(Product a) f (reshape x))
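{- Illustration, not part of the TypedFlow source. mapTT reduces mapping
over several leading axes to mapping over one. It reshapes the input so
that the leading axes a become a single axis of size Product a, maps
with mapT, and reshapes the leading axis of the result back into a. The
sketch replays this on the same hypothetical flat Tensor record used for
illustration elsewhere. mapT and mapTT here are illustrative stand-ins.
They assume no zero-sized dimensions and a non-empty leading axis.

```haskell
-- Hypothetical flat tensor: shape plus row-major elements.
data Tensor a = Tensor { shape :: [Int], elems :: [a] } deriving (Eq, Show)

-- Split a list into consecutive chunks of length k (k > 0).
chunksOf :: Int -> [a] -> [[a]]
chunksOf _ [] = []
chunksOf k xs = let (h, t) = splitAt k xs in h : chunksOf k t

-- Apply f to each slice along the first axis; f must return one common shape.
mapT :: (Tensor a -> Tensor b) -> Tensor a -> Tensor b
mapT f (Tensor (n : s) xs) = Tensor (n : r) (concatMap elems outs)
  where
    outs = map (f . Tensor s) (chunksOf (product s) xs)
    r    = shape (head outs)
mapT _ _ = error "mapT: need a tensor with at least one axis"

-- View the leading axes a as one axis of size Product a, map, restore a.
mapTT :: [Int] -> (Tensor a -> Tensor b) -> Tensor a -> Tensor b
mapTT a f (Tensor s xs) = Tensor (a ++ r) ys
  where
    inner = drop (length a) s
    Tensor (_ : r) ys = mapT f (Tensor (product a : inner) xs)
```

For instance, mapping a sum over the last axis of a [2,2,3] tensor yields
a [2,2] tensor.
-}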

-- | Zip a function along the first dimension of two tensors
zipWithT :: forall (n :: Nat) (s :: [Nat]) (t :: Typ) (s1 :: [Nat]) (t1 :: Typ) (s2 :: Shape)  (t2 :: Typ).
            KnownShape s => KnownShape s1 => KnownNat n=> KnownTyp t => KnownTyp t1
            => (T s t -> T s1 t1 -> T s2 t2)
            -> Tensor (n ': s) t
            -> Tensor (n ': s1) t1
            -> Tensor (n ': s2) t2
zipWithT f x y = ZipT Sat typeSShape typeSShape f x y

-- | Zip a function along the first few dimensions of two tensors, given by the first type parameter
zipWithTT :: forall a (s :: [Nat]) (s1 :: [Nat]) (s2 :: Shape) (t :: Typ) (t1 :: Typ)  (t2 :: Typ).
            KnownTyp t1 => KnownTyp t => KnownShape s => KnownShape s1 => KnownShape a => KnownShape s2 => KnownTyp t2
            => (T s t -> T s1 t1 -> T s2 t2)
            -> Tensor (a ++ s) t
            -> Tensor (a ++ s1) t1
            -> Tensor (a ++ s2) t2
zipWithTT f x y = 
            prodHomo @a @s1 #>
            prodHomo @a @s2 #>
            prodHomo @a @s #>
            knownProduct @a ?>
            knownAppend @a @s1 ?>
            knownAppend @a @s2 ?>
            knownAppend @a @s ?>
            reshape (zipWithT @(Product a) f (reshape x) (reshape y))

-- | Zip a function along the first dimension of three tensors
zipWith3T :: forall (n :: Nat) (s :: [Nat]) (t :: Typ) (s1 :: [Nat]) (t1 :: Typ) (s2 :: Shape)  (t2 :: Typ) (s3 :: Shape)  (t3 :: Typ).
             KnownShape s => KnownShape s1 => KnownShape s2 => KnownShape s3 => KnownNat n => KnownTyp t3 => KnownTyp t => KnownTyp t1 => KnownTyp t2
            => (T s t -> T s1 t1 -> T s2 t2 -> T s3 t3)
            -> Tensor (n ': s) t
            -> Tensor (n ': s1) t1
            -> Tensor (n ': s2) t2
            -> Tensor (n ': s3) t3
zipWith3T = Zip3T Sat typeSShape typeSShape typeSShape

-- | Size-preserving convolution operation.
convolution :: forall outputChannels filterSpatialShape inChannels s t.
               KnownShape s => KnownNat inChannels => KnownNat outputChannels => KnownShape filterSpatialShape
            => KnownAlgebraic t
            => Length filterSpatialShape <= 3
            => Length s ~ Length filterSpatialShape
            => T (s ++ '[inChannels]) t -- ^ input tensor
            -> T (filterSpatialShape ++ '[inChannels,outputChannels]) t -- ^ filters
            -> T (s ++ '[outputChannels]) t
convolution x filters = knownAppend @s @'[outputChannels] ?>
                        knownAppend @s @'[inChannels] ?>
  squeeze0 (Convolution (natSat @1) (natSat @inChannels) (natSat @outputChannels) (typeSShape @filterSpatialShape) (typeSShape @s)
             (expandDim0 x)
             filters)


-- | Softmax along the second dimension (@n@), independently for each of the @bs@ rows
softmaxInternal :: forall bs n w. KnownNat bs => KnownBits w => KnownNat n => T '[bs,n] ('Typ 'Float w) -> T '[bs,n] ('Typ 'Float w)
softmaxInternal = Softmax (natSat @bs) (natSat @n)

-- | Softmax of a vector (along its only dimension)
softmax0 :: forall n w. KnownBits w => KnownNat n
         => T '[n] ('Typ 'Float w) -> T '[n] ('Typ 'Float w)
softmax0 = reshape . softmaxInternal . reshape @[1,n]

-- | Softmax along the second dimension
softmax1 :: forall n m w.  KnownBits w => KnownNat n => KnownNat m
         => T '[m,n] ('Typ 'Float w) -> T '[m,n] ('Typ 'Float w)
softmax1 = mapT softmax0
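
-- Illustration (hypothetical standalone sketch, not part of the
-- TypedFlow API): list models of 'softmax0' on a vector and of
-- 'softmax1', which applies it to every row of a matrix. Subtracting
-- the maximum before exponentiating leaves the result unchanged and
-- avoids overflow. This sketch does not claim that the TensorFlow
-- kernel does the same.
--
-- > softmaxList :: [Double] -> [Double]
-- > softmaxList xs = map (/ total) es
-- >   where m     = maximum xs                  -- shift for numerical stability
-- >         es    = map (\x -> exp (x - m)) xs
-- >         total = sum es
-- >
-- > softmax1List :: [[Double]] -> [[Double]]
-- > softmax1List = map softmaxList             -- one softmax per row
-- >
-- > -- softmaxList [0, 0] == [0.5, 0.5]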

argmaxInternal :: forall n s0 s1 t u. KnownNat n => KnownNumeric t => KnownBits u => Sat KnownNat n -> SShape s0 -> SShape s1 -> T (s0 ++ (n ': s1)) t -> T (s0 ++ s1) ('Typ 'Int u)
argmaxInternal _n s0 s1 = UnOp (Axis1Op s1 (ArgMax (natSat @n))) s0

-- | Splitting a shape at axis @n@ and appending the two parts gives back the shape.
axisSplitApp :: Axis n s -> (Take n s ++ Drop n s) :~: s
axisSplitApp AxZero = Refl
axisSplitApp (AxSucc n) = case axisSplitApp n of
  Refl -> Refl


-- | Like 'axisSplitApp', but with the dimension at axis @n@ made explicit.
axisSplitApp' :: Axis n s -> (Take n s ++ (At n s ': Drop ('Succ n) s)) :~: s
axisSplitApp' AxZero = Refl
axisSplitApp' (AxSucc n) = case axisSplitApp' n of
  Refl -> Refl


-- | Argmax along axis @n@
argmax :: forall m n u s t. (KnownShape s, KnownBits u, KnownNat m, KnownNumeric t) => Axis n s -> Tensor (Take n s ++ (m ': Drop n s)) t -> Tensor s ('Typ 'Int u)
argmax n = case axisSplitApp n of
  Refl -> argmaxInternal (natSat @m) (sShapeTake' n (typeSShape @s)) (sShapeDrop' n s)
  where s = typeSShape @s

-- | Argmax along the first dimension
argmax0 :: forall u n s t. (KnownNat n, KnownShape s, KnownBits u, KnownNumeric t) => T (n ': s) t -> T s ('Typ 'Int u)
argmax0 = argmaxInternal (natSat @n) Unit (typeSShape @s)

-- | Argmax along the second dimension
argmax1 :: forall u m n s t. (KnownNat n, KnownNat m, KnownShape s, KnownBits u, KnownNumeric t) => T (m ': n ': s) t -> T (m ': s) ('Typ 'Int u)
argmax1 = argmaxInternal (natSat @n) (natSat @m :* Unit) (typeSShape @s)
-- argmax1 = mapT argmax0 -- equivalent?
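
-- Illustration (hypothetical standalone sketch, not part of the
-- TypedFlow API): list models of 'argmax0' and 'argmax1' on a matrix
-- given as a list of rows, so that dimension 0 indexes rows. 'argmax0'
-- returns, for each column, the index of the row holding the maximum.
-- 'argmax1' returns, for each row, the index of its maximal column.
-- This model resolves ties towards the lowest index. That is an
-- assumption of the sketch, not something stated by TypedFlow.
--
-- > import Data.List (transpose)
-- >
-- > argmaxList :: [Double] -> Int
-- > argmaxList xs = head [ i | (i, x) <- zip [0 ..] xs, x == maximum xs ]
-- >
-- > argmax0List, argmax1List :: [[Double]] -> [Int]
-- > argmax0List = map argmaxList . transpose   -- reduce over rows
-- > argmax1List = map argmaxList               -- reduce within each row
-- >
-- > -- argmax0List [[1,5],[3,4]] == [1,0]
-- > -- argmax1List [[1,5],[3,4]] == [1,1]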

-- | Cast the element type.
cast :: forall u s t. KnownTyp t => KnownShape s => KnownTyp u => T s t -> T s u
cast = appRUnit @s #> UnOp Cast (typeSShape @s)

-- | (dense) softmax cross entropy with logits.
softmaxCrossEntropyWithLogits :: forall numClasses.
     KnownNat numClasses => Tensor '[numClasses] Float32 -- ^ labels
  -> Tensor '[numClasses] Float32 -- ^ logits
  -> Tensor '[] Float32
softmaxCrossEntropyWithLogits  =
  BinOp SoftmaxCrossEntropyWithLogits
  Unit (typeSShape @ '[numClasses]) typeSTyp (typeSShape @ '[numClasses]) typeSTyp


-- | Computes sigmoid cross entropy given logits. Measures the
-- probability error in discrete classification tasks in which each
-- class is independent and not mutually exclusive. For instance, one
-- could perform multilabel classification where a picture can contain
-- both an elephant and a dog at the same time. See
-- https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
sigmoidCrossEntropyWithLogits :: forall s w.
  KnownBits w => KnownShape s => Tensor s (Flt w) -- ^ labels
                              -> Tensor s (Flt w) -- ^ logits
                              -> Tensor s (Flt w)
sigmoidCrossEntropyWithLogits  =
  appRUnit @s #> BinOp SigmoidCrossEntropyWithLogits 
    (typeSShape @s)      Unit typeSTyp Unit typeSTyp

-- | sparse softmax cross entropy with logits.
sparseSoftmaxCrossEntropyWithLogits :: forall numClasses t.
   KnownNat numClasses => KnownBits t =>
  Tensor '[] Int32                   -- ^ desired label (class index)
  -> Tensor '[numClasses] (Flt t) -- ^ logits (unnormalised scores) for each class
  -> Tensor '[] (Flt t) 
sparseSoftmaxCrossEntropyWithLogits  =
  BinOp SparseSoftmaxCrossEntropyWithLogits Unit Unit typeSTyp (typeSShape @ '[numClasses]) typeSTyp
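
-- Illustration (hypothetical standalone sketch, not part of the
-- TypedFlow API): list models of the three losses above. They follow
-- the formulas given in the TensorFlow documentation of
-- tf.nn.softmax_cross_entropy_with_logits,
-- tf.nn.sparse_softmax_cross_entropy_with_logits and
-- tf.nn.sigmoid_cross_entropy_with_logits. For the sigmoid version,
-- the numerically stable form max(x,0) - x*z + log(1 + exp(-|x|)) is
-- used, where x is the logit and z is the label.
--
-- > logSoftmax :: [Double] -> [Double]
-- > logSoftmax xs = map (\x -> x - m - log (sum [ exp (y - m) | y <- xs ])) xs
-- >   where m = maximum xs
-- >
-- > softmaxXent :: [Double] -> [Double] -> Double        -- labels, logits
-- > softmaxXent labels logits = negate (sum (zipWith (*) labels (logSoftmax logits)))
-- >
-- > sparseSoftmaxXent :: Int -> [Double] -> Double       -- label index, logits
-- > sparseSoftmaxXent label logits = negate (logSoftmax logits !! label)
-- >
-- > sigmoidXent :: Double -> Double -> Double            -- label, logit (per element)
-- > sigmoidXent z x = max x 0 - x * z + log (1 + exp (negate (abs x)))
-- >
-- > -- sparseSoftmaxXent 1 [0, 0] == log 2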

-- | Reverse the order of the elements of a vector.
reverseT :: KnownTyp t => KnownNat n => T '[n] t -> T '[n] t
reverseT = UnOp (Axis1Op Unit (ReverseT Sat)) Unit

-- | One hot vector along axis 0
oneHot0 :: forall numClasses w s t. KnownNat numClasses => KnownNumeric t => KnownBits w =>
  (KnownShape s) =>
  Tensor s ('Typ 'Int w) -> Tensor (numClasses ': s) t
oneHot0 = UnOp (Axis1Op (typeSShape @s) (OneHot Sat)) Unit

-- | One hot vector along axis 1
oneHot1 :: forall numClasses w s m t. KnownBits w =>KnownShape s => KnownNat numClasses => KnownNat m => KnownNumeric t => Tensor (m ': s) ('Typ 'Int w) -> Tensor (m ': numClasses ': s) t
oneHot1 = mapT oneHot0
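
-- Illustration (hypothetical standalone sketch, not part of the
-- TypedFlow API): a list model of one-hot encoding for a vector of
-- indices. 'oneHot1' puts the classes along the second dimension, with
-- one row per index. 'oneHot0' puts them along the first dimension,
-- with one column per index, so its result is the transpose of the
-- 'oneHot1' result.
--
-- > import Data.List (transpose)
-- >
-- > oneHotList :: Int -> Int -> [Double]                -- numClasses, index
-- > oneHotList numClasses i = [ if j == i then 1 else 0 | j <- [0 .. numClasses - 1] ]
-- >
-- > oneHot1List, oneHot0List :: Int -> [Int] -> [[Double]]
-- > oneHot1List numClasses = map (oneHotList numClasses)
-- > oneHot0List numClasses = transpose . oneHot1List numClasses
-- >
-- > -- oneHot1List 3 [2, 0] == [[0,0,1],[1,0,0]]
-- > -- oneHot0List 3 [2, 0] == [[0,1],[0,0],[1,0]]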

-- | Generate a random tensor with the given distribution. Fresh noise
-- is sampled for each element in a batch.
noise :: KnownShape s => Distribution s t -> Gen (T s t)
noise d = do
  noiseId <- GPId -- necessary for correct broadcasting behaviour
  return $ Noise noiseId Unit typeSShape d

-- | @clipByValue lo hi x@ clips every element of @x@ to the interval [@lo@, @hi@].
clipByValue :: forall s t. KnownShape s => KnownBits t => Float -> Float -> T s (Flt t) -> T s (Flt t)
clipByValue lo hi = appRUnit @s #> UnOp (Float1Op (ClipByValue lo hi)) (typeSShape @s)

-- | (where_ c x y)[i] = if c[i] then x[i] else y[i]
where_ :: T s TFBool -> T s t -> T s t -> T s t
where_ = Where


-- | Selection of a tensor (note: this is a strict operation)
if_ :: forall s t. KnownShape s => Scalar TFBool -> T s t -> T s t -> T s t
if_ = If -- FIXME: part of the workaround for https://github.com/tensorflow/tensorflow/issues/21901
-- if_ x = appRUnit @s $ where_ (broadcastTT @s x)

-- | @(gather x ix)[k] = x[ix[k]]@. See https://www.tensorflow.org/api_docs/python/tf/gather
gather :: forall n indexShape s t. KnownShape s => KnownNat n => KnownShape indexShape => T (n ': s) t -> T indexShape Int32 -> T (indexShape ++ s) t
gather = Gather typeSShape Unit (natSat @n) typeSShape
-- gather params ix = GatherND (typeSShape @'[n]) (typeSShape @s) (typeSShape @indexShape) params $
--   prodHomo @indexShape @'[1] $
--   (reshapeAuto ix)

-- | @(lookupT i xs) = xs[i]@. This function returns an element of a
-- tensor at a dynamic index. It is a version of 'gather'
-- specialised to a scalar index.
lookupT :: KnownShape xs => KnownNat n => Scalar Int32 -> Tensor (n ': xs) t -> Tensor xs t
lookupT ix xs = gather xs ix
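
-- Illustration (hypothetical standalone sketch, not part of the
-- TypedFlow API): on a vector and a vector of indices, 'gather' and
-- 'lookupT' behave like list indexing. Out-of-range indices are not
-- modelled.
--
-- > gatherList :: [a] -> [Int] -> [a]
-- > gatherList xs ix = map (xs !!) ix    -- result[k] = xs[ix[k]]
-- >
-- > lookupList :: Int -> [a] -> a
-- > lookupList i xs = xs !! i
-- >
-- > -- gatherList "abcd" [3, 0, 3] == "dad"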

-- | 2D max-pooling layer with a @windowx@ by @windowy@ window (the first two type arguments).
maxPool2D :: forall windowx windowy height width channels t.
             KnownNat height => KnownNat width
          => KnownNat channels
          => (KnownNat windowx, KnownNat windowy, KnownBits t) =>
             T '[windowx*width,windowy*height,channels] (Flt t)
          -> T '[width,height,channels] (Flt t)
maxPool2D x = squeeze0 (Pool (natSat @1) (typeSShape @'[windowx,windowy]) MaxPool (natSat @channels) (typeSShape @'[width,height]) (expandDim0 x))

-- | 1D max-pooling layer. The window size is the first type argument.
maxPool1D :: forall window width channels t.
             KnownNat width => KnownNat channels => (KnownNat window,KnownBits t) =>
             T '[window*width,channels] (Flt t) -> T '[width,channels] (Flt t)
maxPool1D x = squeeze0 (Pool (natSat @1) (typeSShape @'[window]) MaxPool (natSat @channels) (typeSShape @'[width]) (expandDim0 x))
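
-- Illustration (hypothetical standalone sketch, not part of the
-- TypedFlow API): a list model of 'maxPool1D'. It assumes
-- non-overlapping windows whose stride equals the window size, which is
-- what the shapes @window*width@ -> @width@ suggest. The input is a list
-- of positions, and each position holds one value per channel.
--
-- > chunksOf :: Int -> [a] -> [[a]]
-- > chunksOf _ [] = []
-- > chunksOf k xs = let (h, t) = splitAt k xs in h : chunksOf k t
-- >
-- > maxPool1DList :: Int -> [[Double]] -> [[Double]]    -- window, positions
-- > maxPool1DList window = map (foldr1 (zipWith max)) . chunksOf window
-- >
-- > -- maxPool1DList 2 [[1,9],[4,2],[0,0],[3,5]] == [[4,9],[3,5]]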


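-- | Run a generator, returning its result, the final generator state and
-- the variables it declares.
doExtractVars :: Gen a -> (a, GState, [VarInfo])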
doExtractVars p = runRWS (extractVars p) () initialGstate

extractVars :: Gen a -> RWS () [VarInfo] GState a
extractVars (GPState f) = state f
extractVars GPId = do
  GState {..} <- get
  put GState {nextVar=nextVar+1,..}
  return nextVar
extractVars (GPVariable trainable name i) = do
  -- i <- mapM extractVars initial
  case i of
    Nothing -> return ()
    Just i' -> when (not (null (freeVarsT i'))) $ error "extractVars: variable initialiser contains free variables"
  GState {..} <- get
  let r = Ref name typeSShape typeSTyp
  tell [VarInfo trainable r i]
  return r
extractVars (GPApp a b) = do f <- extractVars a; x <- extractVars b; return (f x)
extractVars (GPBind a f) = do
  a' <- extractVars a
  extractVars (f a')
extractVars (GPReturn x) = return x



================================================
FILE: TypedFlow/Broadcast.hs
================================================
{-# LANGUAGE InstanceSigs #-}
{-|
Module      : TypedFlow.Broadcast
Description : Broadcasting of abstract tensor expressions
Copyright   : (c) Jean-Philippe Bernardy, 2018
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental

This module provides operations on the abstract representation of
tensor operations. It is not normally imported directly by users.
-}

{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE CPP #-}
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE NoStarIsType #-}
#endif

module TypedFlow.Broadcast (
  -- * broadcasting
  doBroadcast,doBroadcastSingle,mapPlaceHolders, ConsSh, unopInputShape,
  -- * helpers which are also useful elsewhere
  -- ** reshapes
  reshape, reshapeAuto, reshapeFrom, reshapeTo, inflate2, flatten2,
  permToFun, 
  -- ** transpositions
  transpose01, transposeN, transposeN', transposeN01,
  -- ** others
  concatT', range,
  ) where

import Control.Monad.State
-- import Data.Kind (Type,)
import Data.Proxy
import Data.Type.Equality
import GHC.TypeLits
import Prelude hiding (RealFrac(..))
import System.IO.Unsafe
import TypedFlow.Memo2 hiding (Comp)
import TypedFlow.Types (T(..), type (∘)(..))
import TypedFlow.Types hiding (T)
import TypedFlow.Types.Proofs


data GS = GS { gsUnique :: Unique }


type G = StateT GS IO 

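-- | Run a 'G' computation, starting from the given counter for unique
-- identifiers. The result is extracted with 'unsafePerformIO'.
runG :: Unique -> G x -> x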
runG u m = fst (unsafePerformIO  (runStateT m GS { gsUnique = u }))

-- | Apply 'doBroadcast' to a single tensor expression.
doBroadcastSingle :: forall s t. (KnownShape s, KnownTyp t) => T s t -> T s t
doBroadcastSingle x = case doBroadcast @'[ '("_doBroadcastSingle" , s , t) ] (PHT x :* Unit) of
  PHT x' :* Unit -> x'

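-- | Eliminate the high-level broadcasting constructs ('BroadcastT',
-- 'MapT', 'ZipT' and 'Zip3T') from the given placeholders, expressing
-- them with lower-level tensor operations.
doBroadcast :: All KnownPlaceholder ps => Placeholders ps -> Placeholders ps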
doBroadcast phs = runG 0 $ do
  F3m' bc <- mkBroadcastFn
  let broadcast :: forall n s t. BroadcastFn n s t
      broadcast = unwrapBCFn bc
  F2m' gBC' <- mkGenerateBC broadcast
  let generateBC :: forall s t. GenBCFn s t
      generateBC = unwrapGBCFn gBC'
      generateBCMany :: forall ps. All KnownPlaceholder ps => Placeholders ps -> G (Placeholders ps)
      generateBCMany = \case
        Unit -> return Unit
        (PHT x :* xs) -> do
          x' <- generateBC x
          xs' <- generateBCMany xs
          return (PHT x' :* xs')
  generateBCMany phs


-- | Return a fresh unique identifier.
getUnique :: G Unique
getUnique = do
  u <- gets ((1+) . gsUnique)
  modify $ \GS {} -> GS {gsUnique = u,..}
  return u

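-- | Rewrite a tensor expression so that it no longer contains 'MapT',
-- 'ZipT', 'Zip3T' or 'BroadcastT' nodes. The bodies of such nodes are
-- processed recursively and then broadcast with the given broadcasting
-- function. All other nodes are rebuilt with their children processed
-- recursively.
generateBC' :: (forall n s t proxy. KnownTyp t => KnownShape s => KnownNat n => Unique -> Bool -> proxy n -> T s t -> G (T (n : s) t))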
            -> (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> G (T s' t'))
            -> forall s t. KnownTyp t
            => SShape s
            -> T s t
            -> G (T s t)
generateBC' broadcast rec (n@Sat :* sR) (Zip3T _ _s1 _s2 _s3 f x y z) = knownSShape sR ?> do
  u <- getUnique
  -- ATTN: it is critical not to make recursive calls on x, y, z here. Doing so would create new nodes, losing sharing and creating problems down the line.
  a' <- rec sR (f (Unbroadcast n u x) (Unbroadcast n u y) (Unbroadcast n u z))
  broadcast u False n a'
generateBC' broadcast rec (n@Sat :* sR) (ZipT _ _s1 _s2 f x y) = knownSShape sR ?> do
  u <- getUnique
  a' <- rec sR (f (Unbroadcast n u x) (Unbroadcast n u y))
  broadcast u False n a'
generateBC' broadcast rec (n@Sat :* sR) (MapT _ _s' f x) = knownSShape sR ?> do
  u <- getUnique
  a' <- rec sR (f (Unbroadcast n u x))
  broadcast u False n a'
generateBC' broadcast rec (n@Sat :* sR) (BroadcastT maybeUnique varyNoise _ _s' a) = knownSShape sR ?> do
  u <- case maybeUnique of
          Nothing -> getUnique
          Just u' -> return u'
  a' <- rec sR a
  broadcast u varyNoise n a'
generateBC' _ _ _ (n@T {}) = return n
generateBC' _ _ _ (n@Noise {}) = return n
generateBC' _ rec _ (BinOp op s0 s1 t1 s2 t2 x y) = knownTyp t1 $ knownTyp t2 $ BinOp op s0 s1 t1 s2 t2 <$> (rec (s0 .+. s1) x) <*> (rec (s0 .+. s2) y)
generateBC' _ rec _ (UnOp op s0 x) = UnOp op s0 <$> rec (s0 .+. unopInputShape op) x
generateBC' _ rec sR (Unbroadcast p u' x) = Unbroadcast p u' <$> rec (p :* sR) x
generateBC' _ rec _ (DirectBroadcast s0 s1 s2 s3 x) = DirectBroadcast s0 s1 s2 s3 <$> (rec (s0 .+. s2) x)
generateBC' _ rec _ (ReshapeFrom s0 x) = reshapeFrom s0 <$> rec s0 x
generateBC' _ rec _ (Transpose s0 t x) = Transpose s0 t <$> (rec s0 x)
generateBC' _ rec _ (Concat s0 s1 xs) = Concat s0 s1 <$> hTraverse (\(Catable m x) -> Catable m <$> (rec (s0 .+. m :* s1) x)) xs
generateBC' _ rec _ (Gather is s0 m s1 x ix) = Gather is s0 m s1 <$> (rec (s0 .+. m :* s1) x) <*> rec (s0 .+. is) ix
generateBC' _ rec _ (GatherND cs es is x ix) = GatherND cs es is <$> (rec (cs .+. es) x) <*> (rec (is *: sListLenAsNat cs) ix)
generateBC' _ rec _ (MatMul s0 a b c x y) = MatMul s0 a b c <$> (rec (s0 .+. a :* b :* Unit) x) <*> (rec (s0 .+. b :* c :* Unit) y)
generateBC' _ rec sR (Where cond x y) = Where <$> rec sR cond <*> rec sR x <*> rec sR y
generateBC' _ rec sR (If cond x y) = If <$> rec Unit cond <*> rec sR x <*> rec sR y
generateBC' _ rec _ (Convolution bs@Sat inChans outChans filterShape s0 x filters) = Convolution bs inChans outChans filterShape s0 <$> (rec (bs :* (s0 *: inChans)) x) <*> (rec (filterShape .+. inChans :* outChans :* Unit) filters)
generateBC' _ rec _ (Pool bs@Sat window pt numChans outSpatial x) = Pool bs window pt numChans outSpatial <$> rec (bs :* (zipWithMulSShapes window outSpatial *: numChans)) x
generateBC' _ rec _ (Softmax bs n x) = Softmax bs n <$> (rec (bs :* n :* Unit) x)
generateBC' _ _ _ _ = error "generateBC': unhandled case"



-- | Boolean conjunction lifted to an applicative functor.
(<&&>) :: Applicative f => f Bool -> f Bool -> f Bool
x <&&> y = (&&) <$> x <*> y
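
-- Illustration (hypothetical standalone sketch): unlike a
-- short-circuiting conjunction, '<&&>' always runs the effects of both
-- of its arguments. This is easy to observe with 'Maybe'.
--
-- > (<&&>) :: Applicative f => f Bool -> f Bool -> f Bool
-- > x <&&> y = (&&) <$> x <*> y
-- >
-- > -- Just False <&&> Nothing == Nothing
-- > -- (a short-circuiting version could return Just False here)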

-- | True if the argument does not contain an expression which should be broadcast.
protoFinished :: Unique -> Bool -> (forall s' t'. Unique -> Bool -> T s' t' -> G Bool) -> T s t -> G Bool
protoFinished u varyNoise rec0 =
  let rec :: forall s t. T s t -> G Bool
      rec = rec0 u varyNoise
  in \case
    BroadcastT _ _ _ _s a -> rec a
    MapT _ s f x -> rec x <&&> rec (f (T (Variable (Ref 0 s typeSTyp))))
    ZipT _ s0 s1 f x y -> rec x <&&> rec y <&&> rec (f (T (Variable (Ref 0 s0 typeSTyp))) (T (Variable (Ref 0 s1 typeSTyp))))  
    Zip3T _ s0 s1 s2 f x y z -> rec x <&&> rec y <&&> rec z <&&> rec (f (T (Variable (Ref 0 s0 typeSTyp))) (T (Variable (Ref 0 s1 typeSTyp))) (T (Variable (Ref 0 s2 typeSTyp))))  
    Softmax _ _ x -> rec x
    DirectBroadcast _ _ _ _ x -> rec x
    GatherND _ _ _ x y -> rec x <&&> rec y
    Noise _ _ _ _ -> return (not varyNoise)
    Where cond x y -> rec cond <&&> rec x <&&> rec y
    If cond x y ->  rec cond <&&> rec x <&&> rec y
    T _ -> return True
    Unbroadcast _p u' _x -> return (u /= u')
    UnOp _op _ x -> rec x
    MatMul _ _ _ _ x y -> rec x <&&> rec y
    BinOp _op _ _ _ _ _ x y -> rec x <&&> rec y
    Gather _is _s0 _m _s1 x ix -> rec x <&&> rec ix
    Transpose _ _t x -> rec x
    ReshapeFrom _s x -> rec x
    Concat _s0  _s1 xs -> (and . htoList) <$> hTraverse (\(Catable _ x) -> K <$> rec x) xs
    Convolution _bs _inChans _outChans _filterShape _s x filters -> rec x <&&> rec filters
    Pool _ _ _ _ _ x  -> rec x
    -- _ -> error "protoFinished: unhandled case"

data K02 t x y = K02 {fromK02 :: t}


mkFinished :: G (F2m G (Sig02 Bool (Sig02 Unique T)) (K02 Bool) ) -- forall s' t'. Unique -> Bool -> T s' t' -> G (F2m _)
mkFinished = memo2 (ordMap @Bool `containing02` (ordMap @Unique `containing02` snMap2 @T)) $
             \rec (Ex02 varyNoise (Ex02 u x)) -> K02 <$> protoFinished u varyNoise (unwrapFin rec) x

unwrapFin :: ((Sig02 Bool (Sig02 Unique T)) s t -> G (K02 Bool s t)) -> Unique -> Bool -> T s t -> G Bool
unwrapFin f u v x = fromK02 <$> f (Ex02 v (Ex02 u x))

data KT s t where
  KT ::  STyp t -> SShape s -> KT s t

type GenBCFn s t = (KnownTyp t, KnownShape s) => T s t -> G (T s t)


unwrapGBCFn :: forall s t. (T s t -> KT s t -> G (T s t)) -> GenBCFn  s t
unwrapGBCFn f x' = f x' (KT typeSTyp typeSShape)

-- isBroadcastT :: T s t -> Bool
-- isBroadcastT (BroadcastT {}) = True
-- isBroadcastT _ = False

mkGenerateBC :: (forall n s t. BroadcastFn n s t) -> G (F2m' G T KT T)
mkGenerateBC broadcast = memo2' (snMap2 @T) $
                         \rec x (KT t s) -> knownTyp t $ do
                           r <- generateBC' broadcast (\sh' x' -> rec x' (KT typeSTyp sh')) s x
                           -- when (isBroadcastT r) $ liftIO $ putStrLn "YIKES!"
                           return r

newtype BC'd (n :: Nat) (s :: Shape) (t :: Typ) = BC'd {fromBC'd :: (T (n : s) t)}

data KTn n s t where
  KTn ::  STyp t -> SShape s -> KTn n s t

type BroadcastFn n s t = forall proxy. (KnownNat n, KnownShape s, KnownTyp t) => Unique -> Bool -> proxy n -> T s t -> G (T (n : s) t)

unwrapBCFn :: ((Sig03 Unique (Sig03 Bool (Sig12 (Sat KnownNat) T))) n s t -> KTn n s t -> G (BC'd n s t)) -> BroadcastFn n s t
unwrapBCFn f u v _n x' = fromBC'd <$> f (Ex03 u (Ex03 v (Ex12 natSat x'))) (KTn typeSTyp typeSShape)

mkBroadcastFn :: G (F3m' G (Sig03 Unique (Sig03 Bool (Sig12 (Sat KnownNat) T))) KTn BC'd)
mkBroadcastFn = do
  F2m fin <- mkFinished
  memo3' (ordMap @Unique `containing03` (ordMap @Bool `containing03` (verifMap1 @(Sat KnownNat) `containing12` snMap2 @T))) $
    \rec (Ex03 u (Ex03 v (Ex12 n x))) (KTn st sh) ->
      BC'd <$> protoBroadcast u v n
                  (\sh' x' -> fromBC'd <$> rec (Ex03 u (Ex03 v (Ex12 n x'))) (KTn typeSTyp sh'))
                  (unwrapFin fin u v) st sh x


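-- | Type-level function prepending the dimension @x@ to the shape of a
-- placeholder description (name, shape, type).
class ConsSh (x :: Nat) (p :: (Symbol,Shape,Typ))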
instance Fun (ConsSh x) where type Ap (ConsSh x) p = '(Frst3 p,x ': Scnd3 p,Thrd3 p)

-- -- | Turns a tensor of indices in a container into a tensor of indices
-- -- in a container of higher rank. The added indexed dimension
-- -- corresponds to the first dimension of the index.
-- broadcastIndex :: forall n containerShape indexShape w.
--   KnownBits w => Sat KnownNat n ->
--   SShape containerShape ->
--   SShape indexShape ->
--   IndexTensor (n ': indexShape) containerShape w ->
--   IndexTensor (n ': indexShape) (n ': containerShape) w
-- broadcastIndex n cs = broadcastIndex' n (sListLenAsNat cs)

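-- | Turn a tensor of indices into a container of rank @containerRank@
-- into indices into a container with one extra leading dimension of
-- size @n@, by prefixing each index with its position along @n@.
broadcastIndex' :: forall n containerRank indexShape w.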
  KnownBits w => Sat KnownNat n ->
  Sat KnownNat containerRank ->
  SShape indexShape ->
  T (n ': indexShape ++ '[containerRank])  ('Typ 'Int w) ->
  T (n ': indexShape ++ '[1 + containerRank]) ('Typ 'Int w)
broadcastIndex' n@(Comp Dict) cr is ix = concatT' ((:*) n is) (natSat @1) cr Unit nIndex ix
  where nIndex :: T (n ': indexShape ++ '[1]) ('Typ 'Int w)
        nIndex = DirectBroadcast Unit Unit ((:*) n Unit) (is .+. (:*) (natSat @1) Unit) range
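
-- Illustration (hypothetical standalone sketch, not part of the
-- TypedFlow API): a list model of 'broadcastIndex'' for the case
-- @indexShape = '[]@. Each of the @n@ index vectors, one per position
-- along the new dimension, is prefixed with its own position. The
-- result then addresses a container with one extra leading dimension.
--
-- > broadcastIndexList :: [[Int]] -> [[Int]]
-- > broadcastIndexList = zipWith (:) [0 ..]
-- >
-- > -- broadcastIndexList [[2,1],[0,3]] == [[0,2,1],[1,0,3]]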

-- directBroadcast0 :: forall n s t. KnownShape s => KnownNat n => T s t -> T (n:s) t
-- directBroadcast0 = appRUnit @s #> DirectBroadcast Unit ((:*) (natSat @n) Unit) (typeSShape @s) Unit

-- broadcastIndexMany :: forall n containerShape indexShape w.
--   KnownBits w =>
--   Sat KnownNat n ->
--   SShape containerShape ->
--   SShape indexShape ->
--   IndexTensor indexShape '[n] w ->
--   IndexTensor (containerShape ++ indexShape) (containerShape ++ '[n]) w
-- broadcastIndexMany _ Unit _ x = x
-- broadcastIndexMany n ((:*) m@Sat cs) is x =
--   knownSShape (cs .+. (*:) is (sListLenAsNat (cs *: n))) ?>
--   -- (m : cs ++ is ++  '[(Length (m : cs ++ [n]))])
--   (broadcastIndex m ((*:) cs n) (cs .+. is) $
--   -- (m : (cs ++ is ++  '[Length (cs ++ [n])]))
--   (appAssocS cs is ((:*) (sListLenAsNat (cs *: n)) Unit) #>
--   -- (m : cs ++ is ++ '[Length (cs ++ [n])])
--   directBroadcast0 $
--   -- (cs ++ is ++  '[Length (cs ++ [n])])
--   broadcastIndexMany n cs is x))
--   -- is

--  Product (filterSpatialShape ++ '[inChannels, outChannels * n])
-- Product ((filterSpatialShape ++ '[inChannels, outChannels]) ++ '[n])

axisOpInputShape :: Axis1Op s1 t s2 u -> SShape s1
axisOpInputShape o = case o of
  ArgMax n -> HSingle n
  OneHot _n -> Unit
  ReduceOp n _ -> HSingle n
  ReverseT n -> HSingle n
  SliceOp _ n _ _ -> HSingle n
  AccessOp n _ -> HSingle n

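-- | The part of the input shape consumed by a unary operation (beyond the leading shape @s0@ of the 'UnOp' node).
unopInputShape :: UnOp s t s' t' -> SShape s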
unopInputShape (Diag n) = n :* Unit
unopInputShape Cast = Unit
unopInputShape (Axis1Op s o) = axisOpInputShape o .+. s
unopInputShape StopGradient = Unit
unopInputShape (Num1Op _) = Unit
unopInputShape (Float1Op _) = Unit
unopInputShape (ExpM n) = n :* n :* Unit
unopInputShape (ZeroTriangle n _ _) = n :* n :* Unit
unopInputShape Conjugate = Unit
unopInputShape RealPart = Unit

protoBroadcast :: forall n s t.
  Unique -- unique identifier of the 'Unbroadcast' nodes marking the inputs (which must not be broadcast)
  -> Bool -- how to expand noise: if True, sample different noise for each index of the new dimension
  -> Sat KnownNat n -- added dimension's size
  -> (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> G (T (n ': s') t')) -- recursive case
  -> (forall s' t'. T s' t' -> G Bool) -- test if we're done
  -> STyp t -- representation of the type
  -> SShape s -- representation of the shape
  -> T s t -- tensor (expression) to broadcast
  -> G (T (n ': s) t) -- the expression broadcast along the new first dimension
protoBroadcast u varyNoise n@(Comp Dict) rec finished ty s tensor = do
  isFinished <- finished tensor
  case isFinished of
    True -> simpleBC
    False -> knownTyp ty $ case tensor of
      BroadcastT {} -> error "BroadcastT case remaining, this should have been dealt with by generateBC"
      MapT {} -> error "MapT case remaining, this should have been dealt with by generateBC"
      ZipT {} -> error "ZipT case remaining, this should have been dealt with by generateBC"
      Zip3T {} -> error "Zip3T case remaining, this should have been dealt with by generateBC"
      Softmax bs@Sat m@Sat x -> prodAssocS n bs m #> do
        x' <- rec (typeSShape) x
        return (reshapeAuto (Softmax (satMul n bs) m (reshapeAuto x')))
      DirectBroadcast s0 s1 s2 s3 x -> do
        x' <- (rec (s0 .+. s2) x)
        return (DirectBroadcast (n :* s0) s1 s2 s3 x')
      GatherND cs es is x ix -> do
        xFinished <- finished x
        case xFinished of
          True -> GatherND cs es (n :* is) x <$> (rec (is *: sListLenAsNat cs) ix)
          False -> do
              ix' <- rec (is *: sListLenAsNat cs) ix
              x' <- (rec (cs .+. es) x)
              return (GatherND (n :* cs) es (n :* is) x' (broadcastIndex' n (sListLenAsNat cs) is ix'))
      Noise v s0 s1 x -> if varyNoise then return (Noise v (n :* s0) s1 x) else simpleBC
      -- When varying noise, we extend the shape of the noise (so that
      -- more samples are drawn); otherwise we copy the noise using
      -- simple broadcasting.
      Pool bs@Sat window pt numChans outSpatial x ->
        (knownSShape (zipWithMulSShapes window outSpatial *: numChans) ?>
         (prodAssocS n bs (productS (zipWithMulSShapes window outSpatial *: numChans)) #>
         (prodAssocS n bs (productS (outSpatial *: numChans)) #> do
             x' <- (rec typeSShape x)
             return $ (reshapeFrom (satMul n bs :* outSpatial *: numChans) $
                       Pool (satMul n bs) window pt numChans outSpatial (reshapeAuto x')))))
      Where cond x y -> Where <$> (rec s cond) <*> (rec s x) <*> (rec s y)
      If cond x y -> do
        condFinished <- finished cond
        case condFinished of
          True -> If cond <$> (rec s x) <*> (rec s y)
          False ->  error "broadcast on 'if' condition not implemented"
      T _ -> error "panic: broadcast constant should be finished!"
      Unbroadcast p@Sat u' x
        | u == u' -> return $ case testEq p n of
            Nothing -> UnOp (error "panic.unbroadcast.unit") Unit x
            Just Refl -> x
        | otherwise -> knownSShape s ?> do
            x' <- (rec (p :* s) x)
            return (Unbroadcast p u' (transpose01 x'))
          -- An incomplete broadcast (in another dimension).
      MatMul s0 a@Sat b@Sat c@Sat x y -> do
        yFinished <- finished y
        case (s0,yFinished) of
           (Unit,True) -> do
             -- this optimisation is absolutely critical to implement dense
             -- layers efficiently (at least with TF 1.3). (about 10x performance increase)
             x' <- (rec (a :* b :* Unit) x)
             return $ inflate2 (MatMul Unit (satMul n a) b c (flatten2 x') y)
           _ -> MatMul (n :* s0) a b c  <$> (rec (s0 .+. a :* b :* Unit) x) <*> (rec (s0 .+. b :* c :* Unit) y)
      BinOp op s0 s1 t1 s2 t2 x y -> knownTyp t1 $ knownTyp t2 $ do
        BinOp op (n :* s0) s1 t1 s2 t2 <$> (rec (s0 .+. s1) x) <*> (rec (s0 .+. s2) y)
      UnOp op s0 x -> UnOp op (n :* s0) <$> (rec (s0 .+. unopInputShape op) x)
      Gather is s0 m s1 x ix -> do
        xFinished <- finished x
        case (s0,xFinished) of -- this optimisation is important to get efficient embeddings (???)
          (Unit,True) -> Gather (n :* is) Unit m s1 x <$> (rec is ix)
          _ -> Gather is (n :* s0) m s1 <$> (rec (s0 .+. m :* s1) x) <*> (rec (s0 .+. is) ix)
      Transpose s0 t x -> Transpose (n :* s0) (PermSkip t) <$> (rec s0 x)
      ReshapeFrom s0 x -> reshapeFrom (n :* s0) <$> (rec s0 x)
      Concat s0 s1 xs -> do
        Concat (n :* s0) s1 <$> hTraverse (\(Catable m x) -> Catable m <$> (rec (s0 .+. m :* s1) x)) xs
      Convolution bs@(Sat) inChans outChans filterShape s0 x filters -> do
        filtersFinished <- finished filters
        xFinished <- finished x
        case (filtersFinished,xFinished) of
          (True,_) ->
            prodAssocS n bs (productS (s0 *: inChans))  #>
            prodAssocS n bs (productS (s0 *: outChans)) #>
            knownSShape (s0 *: inChans)                 ?> do
              x' <- (rec typeSShape x)
              return $ reshapeFrom (satMul n bs :* s0 *: outChans) 
                      (Convolution (satMul n bs) inChans outChans filterShape s0 (reshapeAuto x') filters)
          (_,True) ->
            knownSShape (filterShape .+. inChans :* outChans :* Unit) ?>
            knownSShape (bs :* s0 .+. outChans :* Unit) ?> do
              filters' <- rec typeSShape filters
              return $ transposeN' $
                reshapeProven (ANat bs !:* AShape s0 *:! (ANat outChans :*: ANat n))
                              ((ANat bs !:* AShape s0 *:! ANat outChans) *:! ANat n) $
                Convolution bs inChans (outChans `satMul` n) filterShape s0 x $
                reshapeProven ((AShape filterShape :++: (ANat inChans !:* Single (ANat outChans))) *:! ANat n)
                              (AShape filterShape :++: ANat inChans !:* Single (ANat outChans :*: ANat n)) $
                transposeN $ filters'

          _ -> error "broadcast on both convolution filter and data not implemented"
      _ -> error "protoBroadcast: unhandled case" 
  where simpleBC :: G (T (n ': s) t)
        simpleBC = appRUnit @s #>
                   return (DirectBroadcast Unit (n :* Unit) s Unit tensor)
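
-- Illustration (hypothetical standalone sketch, not part of the
-- TypedFlow API): the optimised 'MatMul' case of 'protoBroadcast'
-- relies on the following identity. When only the left operand varies
-- along the new dimension, multiplying each of the @n@ left matrices by
-- the same right operand gives the same rows as one product of the
-- stacked left operands (what @flatten2@ produces). @inflate2@ then
-- splits those rows back into @n@ blocks.
--
-- > import Data.List (transpose)
-- >
-- > mm :: [[Double]] -> [[Double]] -> [[Double]]
-- > mm x y = [ [ sum (zipWith (*) row col) | col <- transpose y ] | row <- x ]
-- >
-- > -- For any list xs of matrices whose column count matches the rows of y:
-- > -- concatMap (`mm` y) xs == mm (concat xs) y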

-- | Inverse of a permutation.
inversePerm :: Permutation a b -> Permutation b a
inversePerm PermId = PermId
inversePerm (PermSkip x) = PermSkip (inversePerm x)
inversePerm PermSwap = PermSwap
inversePerm (PermTrans x y) = PermTrans (inversePerm y) (inversePerm x)

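-- | Interpret a permutation as a function on (0-based) dimension
-- positions, mapping the position of a dimension in the source shape
-- to its position in the target shape.
permToFun :: Permutation s t -> Integer -> Integer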
permToFun = \case
  PermId -> \x -> x
  PermTrans a b -> permToFun b . permToFun a
  PermSwap -> \case
    0 -> 1
    1 -> 0
    x -> x
  PermSkip p -> \case
    0 -> 0
    x -> permToFun p (x-1) + 1
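
-- Illustration (hypothetical standalone sketch): an untyped mirror of
-- 'Permutation'. The real type is indexed by the source and target
-- shapes. The sketch uses the same equations as 'permToFun',
-- 'inversePerm' and 'permN', with the shape list replaced by its
-- length. The names Perm, toFun, inverse and permNModel are made up for
-- this example.
--
-- > {-# LANGUAGE LambdaCase #-}
-- > data Perm = PId | PSkip Perm | PSwap | PTrans Perm Perm
-- >
-- > toFun :: Perm -> Integer -> Integer
-- > toFun = \case
-- >   PId        -> id
-- >   PTrans a b -> toFun b . toFun a          -- apply a first, then b
-- >   PSwap      -> \case { 0 -> 1; 1 -> 0; x -> x }
-- >   PSkip p    -> \case { 0 -> 0; x -> toFun p (x - 1) + 1 }
-- >
-- > inverse :: Perm -> Perm
-- > inverse = \case
-- >   PId        -> PId
-- >   PSkip p    -> PSkip (inverse p)
-- >   PSwap      -> PSwap
-- >   PTrans a b -> PTrans (inverse b) (inverse a)
-- >
-- > -- moves the first of (k+1) dimensions to the end, like 'permN'
-- > permNModel :: Int -> Perm
-- > permNModel 0 = PId
-- > permNModel k = PTrans PSwap (PSkip (permNModel (k - 1)))
-- >
-- > -- map (toFun (permNModel 2)) [0,1,2] == [2,0,1]
-- > -- map (toFun (inverse (permNModel 2)) . toFun (permNModel 2)) [0,1,2] == [0,1,2]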

reshapeAuto :: forall s s0 t. KnownShape s0 => Product s ~ Product s0 => T s0 t -> T s t
reshapeAuto = reshapeFrom typeSShape

reshapeProven :: forall s s0 t n. ShapeX s0 n -> ShapeX s n -> T s0 t -> T s t
reshapeProven s1 s2 = case decideProductEq s1 s2 of
                        Refl -> knownSShape (exprSShape s1) ?> reshapeAuto

reshapeTo :: forall s s0 t proxy. KnownShape s0=> Product s ~ Product s0 => proxy s -> T s0 t -> T s t
reshapeTo _ = reshapeAuto

reshapeFrom :: forall s s0 t. Product s ~ Product s0 => SShape s0 -> T s0 t -> T s t
reshapeFrom _ (ReshapeFrom s1 x) = ReshapeFrom s1 x -- avoid reshaping over and over
reshapeFrom s0 x = ReshapeFrom s0 x


type BatchedPlaceholders n ps = Placeholders (BPH n ps)
type BPH n ps = (Ap (FMap (ConsSh n)) ps)

-- | Batch the model (adding one dimension).
mapPlaceHolders :: forall batchSize shapesAndTypes resShapesAndTypes.
    (KnownNat batchSize, KnownLen shapesAndTypes, KnownLen resShapesAndTypes, All KnownPlaceholder shapesAndTypes, All KnownPlaceholder resShapesAndTypes)
  => Unique
  -> Bool
  -> (Placeholders shapesAndTypes -> Placeholders resShapesAndTypes)
  -> BatchedPlaceholders batchSize shapesAndTypes -> (BatchedPlaceholders batchSize resShapesAndTypes)
mapPlaceHolders u varyNoise f xs = broadcastPlaceholders @batchSize typeSList (f (unbroadcastPlaceholders @batchSize typeSList xs))  where
    unbroadcastPlaceholders :: forall n r. KnownNat n => SList r -> BatchedPlaceholders n r -> Placeholders r
    unbroadcastPlaceholders Unit Unit = Unit
    unbroadcastPlaceholders (_ :* ss) (PHT x :* xs') = PHT (Unbroadcast batchSize u x) :* unbroadcastPlaceholders @n ss xs'
      where batchSize = natSat @n

    broadcastPlaceholders :: forall n r. All KnownPlaceholder r => KnownNat n => SList r -> Placeholders r -> (BatchedPlaceholders n r)
    broadcastPlaceholders Unit Unit = Unit
    broadcastPlaceholders (_ :* ss) (PHT x :* xs) =
      let x' = BroadcastT (Just u) varyNoise (natSat @n) typeSShape x
          xs' = broadcastPlaceholders @n ss xs
      in (PHT x' :* xs')

----------------------------------------------------------------
-- Helper functions

-- | Permutation moving the first dimension to the end.
permN :: SList s -> Permutation (n ': s) (s ++ '[n])
permN Unit = PermId
permN ((:*) _n s) = PermSwap `PermTrans` PermSkip (permN s)

permN01 :: SList s -> Proxy m -> Proxy n -> Permutation (s ++ [m,n]) (s ++ [n,m])
permN01 Unit _ _ = PermSwap
permN01 ((:*) _n s) m n = PermSkip (permN01 s m n)


-- | Transposition: move the first dimension to the end. See the type for the exact permutation of dimensions.
transposeN :: ∀ s n t. KnownNat n => KnownShape s => T (n ': s) t -> T (s ++ '[n]) t
transposeN  = doTranspose typeSShape (permN (typeSList @s))

-- | Transposition: move the last dimension to the front (the inverse of 'transposeN').
transposeN' :: ∀ s n t. KnownNat n => KnownShape s => T (s ++ '[n]) t -> T (n ': s) t
transposeN' = doTranspose (typeSShape @s *: (natSat @n)) (inversePerm (permN (typeSList @s)))


-- | Transposition: swap the last two dimensions.
transposeN01 :: ∀ s m n t. KnownNat n => KnownNat m => KnownShape s => T (s ++ [m,n]) t -> T (s ++ [n,m]) t
transposeN01 = doTranspose (typeSShape @s .+. typeSShape @'[m,n]) (permN01 (typeSList @s) (Proxy @m) (Proxy @n))


-- | Transposition: swap the first two dimensions.
transpose01 :: ∀ s m n t. KnownNat n => KnownNat m => KnownShape s => T (m ': n ': s) t -> T (n ': m ': s) t
transpose01 = doTranspose typeSShape PermSwap

doTranspose :: SShape s0 -> Permutation s0 s -> T s0 t -> T s t
doTranspose _  p (Transpose sh' q x) = doTranspose sh' (PermTrans q p) x
doTranspose sh p x = Transpose sh p x
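
{- Standalone sketch (not part of this module; the helper name is
   hypothetical). It models, on plain lists of dimension sizes, how the
   transposition functions above rearrange a shape, using TensorFlow's
   perm convention: output dimension i is input dimension perm !! i.
   doTranspose fuses nested transpositions into one (via PermTrans); the
   last line checks that composing permutation vectors gives the same
   shape as applying them one after the other.

module Main where

-- Apply a permutation vector to a shape.
permuteShape :: [Int] -> [Int] -> [Int]
permuteShape perm sh = map (sh !!) perm

main :: IO ()
main = do
  let sh = [2, 3, 4]
  print (permuteShape [1, 2, 0] sh)          -- transposeN:   [3,4,2]
  print (permuteShape [2, 0, 1] [3, 4, 2])   -- transposeN':  [2,3,4]
  print (permuteShape [1, 0, 2] sh)          -- transpose01:  [3,2,4]
  print (permuteShape [0, 2, 1] sh)          -- transposeN01: [2,4,3]
  let p = [1, 2, 0]
      q = [1, 0, 2]
  -- applying p then q equals applying the composed vector (map (p !!) q)
  print (permuteShape q (permuteShape p sh) == permuteShape (map (p !!) q) sh)  -- True
-}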


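-- | Concatenate two tensors, given their shapes explicitly, along the
-- dimension that follows the common prefix @s0@; the suffix @s1@ must
-- also be shared.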
concatT' :: ∀ s0 d1 d2 s1 t. KnownTyp t =>
    SShape s0 -> Sat KnownNat d1 -> Sat KnownNat d2 -> SShape s1 -> T (s0 ++ (d1 ': s1)) t -> T (s0 ++ (d2 ': s1)) t -> T (s0 ++ ((d1+d2) ': s1)) t
concatT' s0 d1@(Comp Dict) d2@(Comp Dict) s1 x y = Concat s0 s1 (Catable d1 x :* Catable d2 y :* Unit)
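
{- Standalone sketch (hypothetical helper, not part of TypedFlow): the shape
   rule that the type of concatT' enforces, checked at run time on plain
   lists. The concatenation axis is the length of the prefix s0, which is
   also the axis the Haskell backend passes to Backend.concat.

module Main where

concatShape :: Int -> [Int] -> [Int] -> Maybe [Int]
concatShape axis xs ys
  | axis < 0 || axis >= length xs              = Nothing
  | length xs /= length ys                     = Nothing
  | take axis xs /= take axis ys               = Nothing  -- prefix s0 must agree
  | drop (axis + 1) xs /= drop (axis + 1) ys   = Nothing  -- suffix s1 must agree
  | otherwise =
      Just (take axis xs ++ [xs !! axis + ys !! axis] ++ drop (axis + 1) xs)

main :: IO ()
main = do
  print (concatShape 1 [2, 3, 4] [2, 5, 4])  -- Just [2,8,4]
  print (concatShape 1 [2, 3, 4] [2, 5, 6])  -- Nothing
-}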


-- | Reshape a tensor so that the first dimension is expanded into two.
inflate2 :: ∀ m n s t. KnownTyp t => (KnownNat m, KnownNat n, KnownShape s) => Tensor (m*n ': s) t -> Tensor (m ': n ': s) t
inflate2 = prodAssoc @m @n @(Product s) #> reshape


-- | Reshape a tensor so that the first two dimensions are collapsed into one.
flatten2 :: ∀ m n s t. KnownTyp t => (KnownNat m, KnownNat n, KnownShape s) => Tensor (m ': n ': s) t -> Tensor (m*n ': s) t
flatten2 = prodAssoc @m @n @(Product s) #> reshape
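
{- Standalone sketch (hypothetical helpers on nested lists): flatten2
   corresponds to concatenating the two outermost list levels, and inflate2
   to re-chunking them, which is what a row-major reshape does to the
   leading m and n dimensions. The Product s1 ~ Product s2 constraint on
   reshape amounts to requiring equal element counts.

module Main where

flatten2L :: [[a]] -> [a]
flatten2L = concat

inflate2L :: Int -> [a] -> [[a]]
inflate2L n xs
  | n <= 0    = error "inflate2L: chunk size must be positive"
  | null xs   = []
  | otherwise = take n xs : inflate2L n (drop n xs)

main :: IO ()
main = do
  let x = [[1, 2, 3], [4, 5, 6]] :: [[Int]]             -- shape [2,3]
  print (flatten2L x)                                     -- [1,2,3,4,5,6]
  print (inflate2L 3 (flatten2L x) == x)                  -- True
  print (product [2, 3, 4] == product [6, 4 :: Int])      -- True: [2,3,4] -> [6,4] is allowed
-}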


-- | Reshape a tensor to another shape with the same number of elements.
reshape :: ∀ s2 s1 t. KnownShape s1 => KnownShape s2 => Product s1 ~ Product s2 => Tensor s1 t -> Tensor s2 t
reshape = reshapeAuto

-- | range[i] = i
range :: forall n w. KnownNat n => KnownBits w => T '[n] ('Typ 'Int w)
range = T (Range (natSat @n))


================================================
FILE: TypedFlow/Haskell.hs
================================================
{-|
Module      : TypedFlow.Haskell
Description : Generation of the computation graph using the TensorFlow Haskell bindings.
Copyright   : (c) Jean-Philippe Bernardy, 2017
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental

-}

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}

module TypedFlow.Haskell where

import Data.Type.Equality
import Data.List (genericReplicate)
import GHC.TypeLits
import Control.Monad.State
import TypedFlow.Types
import TypedFlow.Types.Proofs
import TypedFlow.Abstract (newId, permToFun, unopInputShape)
import TypedFlow.Memo
import System.Mem.StableName
import System.IO.Unsafe

import qualified Data.Int as Backend

import qualified TensorFlow.Core        as Backend
import qualified TensorFlow.GenOps.Core as BackCore
import qualified TensorFlow.Minimize    as Backend
import qualified TensorFlow.Ops         as Backend
import qualified TensorFlow.NN          as Backend
-- import qualified TensorFlow.Variable    as Backend
import qualified TensorFlow.Tensor

import qualified Data.IntMap as IM
import Data.IntMap (IntMap)

type BackendShape = BackendTensor ('Typ 'Int 'B32)
type BackendTensor t = Backend.Tensor Backend.Build (HaskType t)
type BackendVariable t = Backend.Tensor Backend.Ref (HaskType t)
type BackendTensorType t = Backend.TensorType (HaskType t)

shapeFromType :: ∀ (s :: Shape). KnownShape s => BackendShape
shapeFromType = shapeVector (typeSShape @s)

-- | Convert a shape to a backend vector of dimension sizes; the "None" sentinel is replaced by -1 (see 'convertNone').
shapeVector :: forall (s::Shape) proxy. All KnownNat s => SList' proxy s -> BackendShape
shapeVector s = shapeFromList (shapeToList'' s)

permToTensor :: SShape s -> Permutation s t -> Backend.Tensor Backend.Build Backend.Int32
permToTensor s p = Backend.vector (map (fromInteger . permToFun p) [0 .. sListLength s - 1])

shapeFromList :: [Integer] -> BackendShape
shapeFromList = Backend.vector . map convertNone

showShapeLen :: ∀ (s::Shape). KnownLen s => Backend.Int32
showShapeLen = fromIntegral (listTypeLen @s)

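-- | Convert a dimension size for the backend. The sentinel 514229
-- stands for an unknown ("None") dimension and is mapped to -1.
convertNone :: Num a => Integer -> a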
convertNone n = (if n == 514229 then (-1) else fromIntegral n)
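
{- Standalone sketch (not part of this module; the helper name is
   hypothetical): the same conversion in isolation. As convertNone shows,
   514229 marks an unknown ("None") dimension, which becomes TensorFlow's -1.

module Main where

import Data.Int (Int32)

toBackendDim :: Num a => Integer -> a
toBackendDim n = if n == 514229 then -1 else fromIntegral n

main :: IO ()
main = print (map toBackendDim [514229, 28, 28] :: [Int32])  -- [-1,28,28]
-}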

-- runWithFeeds

data BT (s :: Shape) (t :: Typ) where
  BT :: forall s t. (BackendTensor t) -> BT s t

data HState = HState {genVars :: IntMap Var
                     ,genPureTable :: SNMap22 Shape Typ T BT
                     -- alternative: use tensorRefFromName and make this closer to the Python backend.
                     }

type BM a = Backend.BuildT (StateT HState (State GState)) a

data Var = forall s t v. TensorFlow.Tensor.TensorKind v => Var (SShape s) (STyp t) (Backend.Tensor v (HaskType t))

initializedVariable :: forall s a. KnownShape s => KnownTyp a => T s a -> BM (Ref s a)
initializedVariable initVal = do
  BT i <- interpretPure initVal
  x <- lift (lift newId)
  v <- backendTensor (typeSTyp @a) $ Backend.initializedVariable i
  let var = (Var (typeSShape @s) (typeSTyp @a) v)
  lift (modify $ \HState{..} -> HState {genVars = IM.insert (fromIntegral x) var genVars,..})
  return (Ref (fromIntegral x) typeSShape typeSTyp )

placeholder :: forall s a. SShape s -> STyp a -> BM (Ref s a)
placeholder s t = do
  x <- lift (lift newId)
  ph <- backendTensor t $ Backend.placeholder (Backend.Shape (map convertNone $ shapeToList' s))
  let var = (Var s t ph)
  lift (modify $ \HState{..} -> HState {genVars = IM.insert (fromIntegral x) var genVars,..})
  return (Ref (fromIntegral x) s t )

interpGen :: Gen a -> BM a
interpGen (GPReturn x) = return x
interpGen (GPVariable _trainable _name initVal) = initializedVariable initVal
interpGen (GPPlaceholder s t _name) = placeholder s t
interpGen (GPModify _ _) = error "GPModify: TODO"
interpGen (GPState f) = lift (lift (state f))
interpGen (GPBind a b) = do x <- interpGen a
                            interpGen (b x)

listProxyLen :: forall proxy s. KnownLen s => proxy s -> Integer
listProxyLen _ = listTypeLen @s

-- genDistr :: forall s s0 t. KnownTyp t => Distribution s t -> SShape s0 -> SShape s -> DOC
-- genDistr d sh s1 = case d of
--   TruncatedNormalD stddev -> funcall "tf.truncated_normal"
--     [showSShape (sh .+. s1), named "stddev" (float stddev), named "dtype" (showTyp @t)]
--   UniformD low high -> funcall "tf.random_uniform" [showSShape (sh .+. s1)
--                                 ,named "minval" (float low)
--                                 ,named "maxval" (float high)
--                                 ,named "dtype" (showTyp @t)]
--   OrthogonalD ->
--     funcall' (funcall "tf.orthogonal_initializer" [named "dtype" (showTyp @t)]) [named "shape" (showSShape (sh .+. s1))]


knownNumeric :: forall t k. KnownNumeric t => (KnownTyp t => Num (HaskType t) => Backend.OneOf '[Backend.Int32, Float, Double] (HaskType t) => k) -> k
knownNumeric = knownNumeric' (typeSTyp @t)

knownNumeric' :: forall t k. KnownNumeric t => STyp t -> (KnownTyp t => Num (HaskType t) => Backend.OneOf '[Backend.Int32, Float, Double] (HaskType t) => k) -> k
knownNumeric' (STyp tk tb Refl) k = case tk of
  SFloat -> case tb of
    SB32 -> k
    SB64 -> k
  SBool -> error "TFNumeric bug"
  SInt -> case tb of
    SB32 -> k
    SB64 -> error "missing in tensorflow: int64 is not supported in matmul T_T"

knownFloatingB :: forall t k. (KnownTyp t, TypKind t ~ 'Float) => (Backend.OneOf '[Float, Double] (HaskType t) => k) -> k
knownFloatingB k = case bitsVal @(TypBits t) of
    SB32 -> k
    SB64 -> k

knownInt :: forall t k. (KnownTyp t, TypKind t ~ 'Int) => (Backend.OneOf '[Backend.Int32, Backend.Int64] (HaskType t) => k) -> k
knownInt k = case bitsVal @(TypBits t) of
    SB32 -> k
    SB64 -> k

backendTensor :: STyp t ->  (Backend.TensorType (HaskType t) => k) -> k
backendTensor (STyp SFloat SB32 Refl) k = k
backendTensor (STyp SInt SB64 Refl) k = k
backendTensor (STyp SBool _ Refl) k = k
backendTensor (STyp SFloat SB64 Refl) k = k
backendTensor (STyp SInt SB32 Refl) k = k

backendTensor' :: forall t k proxy. KnownTyp t => proxy t -> (Backend.TensorType (HaskType t) => k) -> k
backendTensor' _ = backendTensor (typeSTyp @t)


runUnOp :: forall s s1 t s2 u. KnownTyp u => KnownTyp t => BackendTensorType u => SShape s -> UnOp s1 t s2 u -> BT (s++s1) t -> BT (s++s2) u
runUnOp sL op (BT x) = backendTensor (typeSTyp @t) $ case op of
  SliceOp _ sR lo hi -> BT $ BackCore.slice x
    (shapeFromList (replicate (sListLen  sL) 0 ++ [lo] ++ replicate (sListLen sR) 0))
    (shapeFromList (shapeToList' sL ++ [hi-lo] ++ (shapeToList' sR)))
  Axis1Op aop -> case aop of
    (ArgMax _ _) -> knownNumeric @t $ knownInt @u $ BT $ BackCore.argMax x (Backend.scalar sLLen)
    (OneHot _) -> knownNumeric @u $ knownInt @t $  BT $ Backend.oneHot x (Backend.scalar sLLen) (Backend.scalar 1) (Backend.scalar 0)
    ReduceOp _ _sR rop -> knownNumeric @t $ case rop of
      Max -> BT $ BackCore.max x redindices
      Min -> BT $ BackCore.min x redindices
      Sum -> BT $ Backend.sum x redindices
      Mean -> BT $ Backend.mean x redindices
     where redindices = (Backend.vector [fromIntegral (sListLen sL) :: Backend.Int32 ])
  StopGradient -> BT $ BackCore.stopGradient x
  Cast -> BT $ Backend.cast x
  (Num1Op numop) -> knownNumeric @t $ case numop of
    Square -> BT (Backend.mul x x)
    Negate -> BT (Backend.neg x)
    Sign -> BT (Backend.sign x)
    Abs -> BT (Backend.abs x)
    FloorMod -> BT (Backend.floorMod x)
  Float1Op flop -> knownFloatingB @t $ knownFloating @(TypBits u) $ knownFloatingB @u $ case flop of
     Tanh -> BT (BackCore.tanh x)
     Sin -> BT (BackCore.sin x)
     Exp -> BT (BackCore.exp x)
     Sigmoid -> BT (BackCore.sigmoid x)
     Relu -> BT (BackCore.relu x)
     Floor -> BT (BackCore.floor x)
     Round -> BT (BackCore.round x)
     Cos -> BT (BackCore.cos x)
     Log -> BT (BackCore.log x)
     Asin -> BT (BackCore.asin x)
     Acos -> BT (BackCore.acos x)
     Sinh -> BT (BackCore.sinh x)
     Cosh -> BT (BackCore.cosh x)
     Asinh -> BT (BackCore.asinh x)
     Acosh -> BT (BackCore.acosh x)
     Atan -> BT (BackCore.atan x)
     Atanh -> BT (BackCore.atanh x)
     Sqrt -> BT (BackCore.sqrt x)
     HardSigmoid -> error "Haskell: no hard sigmoid defined yet"
     ClipByValue lo hi -> BT $ BackCore.clipByValue x (Backend.scalar $ realToFrac lo) (Backend.scalar $ realToFrac hi)
  Diag _ -> BT $ BackCore.batchMatrixDiag x
 where sLLen = fromIntegral (sListLen sL) :: Backend.Int32

interpretPure :: forall s t. KnownTyp t => KnownShape s => T s t -> BM (BT s t)
interpretPure x = do
  let sn = unsafePerformIO $ makeStableName x
  mv <- snMap22Lookup sn <$> lift (gets genPureTable)
  case mv of
    Just v -> return v
    Nothing -> do
      e  <- interpretPure' (\s x' -> knownSShape s $ interpretPure x') typeSShape x
      lift $ modify (\g -> g {genPureTable = (snMap22Insert (KV sn e)) (genPureTable g)})
      return e
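
{- Standalone sketch (hypothetical names; not the TypedFlow implementation,
   which keys an SNMap22 from TypedFlow.Memo): memoising a function on the
   identity of its argument with stable names. This is the idea interpretPure
   uses to avoid translating a shared subterm more than once. Here the
   argument is forced first, because makeStableName may return a different
   name for an object before and after it is evaluated. The table is a plain
   association list, for illustration only.

module Main where

import Control.Exception (evaluate)
import Data.IORef
import System.Mem.StableName

memoOnIdentity :: IORef [(StableName a, b)] -> (a -> IO b) -> a -> IO b
memoOnIdentity table f x = do
  x' <- evaluate x
  sn <- makeStableName x'
  entries <- readIORef table
  case lookup sn entries of
    Just v  -> return v                 -- same heap object: reuse the result
    Nothing -> do
      v <- f x'
      modifyIORef table ((sn, v) :)
      return v

main :: IO ()
main = do
  table <- newIORef []
  calls <- newIORef (0 :: Int)
  let expensive xs = modifyIORef calls (+ 1) >> return (sum xs)
      shared = [1 .. 1000] :: [Int]
  a <- memoOnIdentity table expensive shared
  b <- memoOnIdentity table expensive shared  -- should find the first entry
  n <- readIORef calls
  print (a, b, n)
-}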

interpNilOp :: forall s t. Backend.TensorType (HaskType t) => NilOp s t -> BM (BT s t)
interpNilOp = \case
  Constant c -> return $ BT $ Backend.scalar c
  Range n@Sat -> knownNumeric @t $ return $
    let start,limit,delta :: HaskType t
        start = 0
        limit = fromIntegral $ natVal n
        delta = 1
    in BT $ Backend.range (Backend.scalar start) (Backend.scalar limit) (Backend.scalar delta)
  Variable (Ref r sr tr) -> do
     tbl <- lift (gets genVars)
     case IM.lookup r tbl of
       Just (Var sx tx x) -> case (testEq sx sr, testEq tx tr) of
          (Just Refl, Just Refl) -> return (BT (Backend.expr x))
          _ -> error "panic: variable does not have the expected type"
       _ -> error "panic: variable not found" 

interpretPure' :: forall s t. KnownTyp t => (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> BM (BT s' t')) -> SShape s -> T s t -> BM (BT s t)
interpretPure' rec sR = knownSShape sR $ backendTensor (typeSTyp @t) $ \case
  Unbroadcast{} -> error "broadcasting operation did not complete!"
  DirectBroadcast s0 s1 s2 s3 x -> do
    BT recx <- rec (s0 .+. s2) x
    let expandedShape = shapeFromList
                          (concat [shapeToList' s0, genericReplicate (sListLength s1) 1
                                  ,shapeToList' s2, genericReplicate (sListLength s3) 1 ])
        targetShape = shapeFromList sR
    return $ BT $ BackCore.broadcastTo (Backend.reshape recx expandedShape) targetShape
   --  Noise noiseId s0 s1 x -> do
   --    return $ (genDistr x s0 s1) <+> (text "# " <> integer noiseId)
  T op -> interpNilOp op
  Where c x y -> do
    BT rc <- rec typeSShape c
    BT rx <- rec typeSShape x
    BT ry <- rec typeSShape y
    return $ BT $ BackCore.select rc rx ry
  UnOp operation s0 x -> do
    recx <- rec (s0 .+. unopInputShape operation) x
    return (runUnOp s0 operation recx)
  MatMul s0 a b c x y  -> do
    BT recx <- rec (s0 .+. a :* b :* Unit) x
    BT recy <- rec (s0 .+. b :* c :* Unit) y
    return $ knownNumeric @t $ BT $ BackCore.batchMatMul recx recy
  BinOp operation s0 s1 t s2 u x y -> knownSShape s0 $ knownSShape s1 $ knownSShape s2 $ knownProduct' s0 $ do
   BT recx <- rec (s0 .+. s1) x
   BT recy <- rec (s0 .+. s2) y
   let reshx = backendTensor t $ Backend.reshape recx (shapeVector (satProd s0 :* s1))
       reshy = backendTensor u $ Backend.reshape recy (shapeVector (satProd s0 :* s2))
   return $ case operation of
     Simple2Op sop  -> case sop of
        Add -> knownNumeric @t $ BT $ Backend.add recx recy
        Divide -> knownNumeric @t $ BT $ BackCore.div recx recy
        Equal -> backendTensor u $ BT $ Backend.equal recx recy
        Subtract -> knownNumeric @t $ BT $ Backend.sub recx recy
        Multiply -> knownNumeric @t $ BT $ Backend.mul recx recy
        Minimum -> knownNumeric @t $ BT $ BackCore.minimum recx recy
        Maximum -> knownNumeric @t $ BT $ BackCore.maximum recx recy
        LessThan ->  knownNumeric' u $ BT $ BackCore.less recx recy
     -- Note: the Python API takes the labels before the logits, whereas the Haskell
     -- binding takes the logits (features) first; hence the swapped operands below.
     -- python: https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits
     -- haskell: https://tensorflow.github.io/haskell/haddock/tensorflow-core-ops-0.2.0.0/TensorFlow-GenOps-Core.html#v:sparseSoftmaxCrossEntropyWithLogits
     SparseSoftmaxCrossEntropyWithLogits -> case t of
        STyp SInt SB32 Refl -> knownFloatingB @t $ BT $ fst $ BackCore.sparseSoftmaxCrossEntropyWithLogits reshy reshx
     SoftmaxCrossEntropyWithLogits -> knownFloatingB @t $ BT $ fst $ BackCore.softmaxCrossEntropyWithLogits reshy reshx
     -- SigmoidCrossEntropyWithLogits -> knownFloatingB @t $ BT $ Backend.sigmoidCrossEntropyWithLogits recy recx -- type is not as general as necessary
  ReshapeFrom s t -> do
    BT rt <- rec s t
    return $ BT $ BackCore.reshape rt (shapeVector sR)
  Concat s0 s1 xs -> do
    let go :: forall s0 s1 ns. SShape s0 -> SShape s1 -> NP (Catable s0 s1 t) ns -> BM [BackendTensor t]
        go _ _ Unit = return []
        go s0' s1' (Catable n y :* ys) = do
          BT y' <- rec (s0' .+. n :* s1') y
          (y' :) <$> go s0' s1' ys
    rxs <- go s0 s1 xs
    return $ BT $ Backend.concat (Backend.scalar (fromIntegral (sListLength s0))) rxs
  Transpose s p x -> do
    BT rx <- rec s x
    return $ BT $ Backend.transpose rx (permToTensor s p)
 --  Gather indexShape s0 m s1 x ix -> do
 --    rx <- rec (s0 .+. ((:*) m s1)) x
 --    rix <- rec indexShape ix
 --    return (func "tf.gather" [rx, rix] [])
 --  GatherND containerShape elementShape indexShape x ix -> do
 --    rx <- rec (containerShape .+. elementShape) x
 --    rix <- rec (indexShape *: (sListLenAsNat containerShape)) ix
 --    return (func "tf.gather_nd" [rx, rix] [])
  Convolution bs inChans outChans filterShape s0 x filters -> do
    BT recx <- rec (bs :* (s0 *: inChans)) x
    BT recFilters <- rec (filterShape .+. inChans :* outChans :* Unit) filters
    case filterShape of
       _width :* _height :* Unit ->
          return $ BT $ knownFloatingB @t $ BackCore.conv2D recx recFilters
       _ -> error "TypedFlow Haskell backend: convolution on an unsupported number of dims"
 --  Pool bs window typ numChans outSpatial x -> do
 --     rx <- rec ((:*) bs (zipWithMulSShapes window outSpatial .+. (:*) numChans Unit)) x
 --     return (func "tf.nn.pool"
 --                  [rx, showSShape window, typ', text (show ("SAME" :: String))]
 --                  [("strides", showSShape window)])
 --   where typ' = text $ (show $ case typ of MaxPool -> "MAX"; AvgPool -> "AVG" :: String)
 -- -- where rec :: forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> DOC
 -- --       rec = generatePure' 
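
{- Standalone sketch (hypothetical helper): the shapes involved in the
   DirectBroadcast case of interpretPure'. The argument has shape s0 ++ s2.
   It is first reshaped to s0 ++ [1,..] ++ s2 ++ [1,..], with one 1 per
   dimension of s1 and of s3, and then broadcast to the result shape sR,
   which for this constructor is presumably s0 ++ s1 ++ s2 ++ s3.

module Main where

expandedShape :: [Integer] -> [Integer] -> [Integer] -> [Integer] -> [Integer]
expandedShape s0 s1 s2 s3 = s0 ++ map (const 1) s1 ++ s2 ++ map (const 1) s3

main :: IO ()
main = do
  let (s0, s1, s2, s3) = ([8], [5], [3], [2])
  print (expandedShape s0 s1 s2 s3)   -- [8,1,3,1]
  print (s0 ++ s1 ++ s2 ++ s3)        -- [8,5,3,2], the broadcast target
-}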



================================================
FILE: TypedFlow/Layers/Core.hs
================================================
{-|
Module      : TypedFlow.Layers.Core
Description : Core layers and combinators.
Copyright   : (c) Jean-Philippe Bernardy, 2017
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental
-}
{-# LANGUAGE CPP #-}
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE NoStarIsType #-}
#endif
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeInType #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE PatternSynonyms #-}

module TypedFlow.Layers.Core
  (
    -- * Dense
    DenseP(..), dense, (#),
    -- * Dropout
    DropProb(..), mkMask, mkDropout, mkDropouts,
    -- * Embedding
    EmbeddingP(..), embedding, 
    -- * Convolutional
    ConvP(..), conv, conv', {-convValid,-} maxPool1D, maxPool2D,
    glu
  )

where
import Prelude hiding (RealFrac(..))
import GHC.TypeLits
import TypedFlow.TF
import TypedFlow.Types
import TypedFlow.Types.Proofs
import TypedFlow.Abstract
import Control.Monad.State (gets)
import Data.Monoid ((<>))
---------------------
-- Linear functions


-- | A dense layer is an affine function from a to b: a weight matrix and a bias.
data DenseP t a b = DenseP {denseWeights :: Tensor '[a,b] t
                           ,denseBiases  :: Tensor '[b] t}

-----------------------
-- Feed-forward layers

-- | Parameters for the embedding layers
newtype EmbeddingP numObjects embeddingSize t = EmbeddingP (Tensor '[numObjects, embeddingSize] t)

instance (KnownNat numObjects, KnownTyp b, KnownNat embeddingSize) => KnownTensors (EmbeddingP numObjects embeddingSize b) where
  travTensor f s (EmbeddingP p) = EmbeddingP <$> travTensor f s p

instance (KnownNat numObjects, KnownBits b, KnownNat embeddingSize) => ParamWithDefault (EmbeddingP numObjects embeddingSize ('Typ 'Float b)) where
  defaultInitializer = EmbeddingP <$> (noise $ UniformD (-0.05) 0.05)

instance (KnownNat numObjects, KnownBits b, KnownNat embeddingSize) => ParamWithDefault (EmbeddingP numObjects embeddingSize ('Typ 'Cmplx b)) where
  defaultInitializer = EmbeddingP <$> (mkComplex <$> (noise $ UniformD (-0.05) 0.05) <*> (noise $ UniformD (-0.05) 0.05))

-- | Embedding layer: look up the embedding vector of the given object index.
embedding :: ∀ embeddingSize numObjects t. KnownNat embeddingSize => KnownNat numObjects =>
             EmbeddingP numObjects embeddingSize t -> Tensor '[] Int32 -> Tensor '[embeddingSize] t
embedding (EmbeddingP param) input = gather param input



instance (KnownNat a, KnownNat b, KnownTyp t) => KnownTensors (DenseP t a b) where
  travTensor f s (DenseP x y) = DenseP <$> travTensor f (s<>"_w") x <*> travTensor f (s<>"_bias") y

instance (KnownNat n, KnownNat m, KnownFloat b) => ParamWithDefault (DenseP b n m) where
  defaultInitializer = DenseP <$> glorotUniform <*> (noise $ TruncatedNormalD 0.1)

-- | Dense layer: multiply the input by the weight matrix and add the bias.
(#), dense :: ∀m n t. KnownNat n => KnownNat m => KnownNumeric t => DenseP t n m -> Tensor '[n] t -> Tensor '[m] t
(DenseP weightMatrix bias) # v = (weightMatrix ∙ v) + bias

dense = (#)
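
{- Standalone sketch (hypothetical helper on plain lists, not TypedFlow's
   API): what dense computes, assuming that ∙ contracts the first dimension
   of the weight matrix with the vector, as the types '[n,m] and '[n]
   giving '[m] indicate. The weights are stored as n rows of length m.

module Main where

import Data.List (transpose)

denseL :: Num a => [[a]] -> [a] -> [a] -> [a]
denseL weights bias v =
  zipWith (+) [sum (zipWith (*) v col) | col <- transpose weights] bias

main :: IO ()
main = print (denseL [[1, 2], [3, 4], [5, 6]] [10, 20] [1, 0, 1 :: Int])
  -- [16,28]: columns [1,3,5] and [2,4,6] dotted with [1,0,1], plus the bias
-}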

-- | A drop probability. (This type is used to make sure one does not
-- confuse keep probability and drop probability)
data DropProb = DropProb Float

-- | Generate a dropout function. The mask applied by the returned
-- function will be constant for any given call to mkDropout.  See
-- 'noise' for the sampling behaviour.
mkDropout :: forall s t. KnownShape s => KnownFloat t => DropProb -> Gen (Tensor s t -> Tensor s t)
mkDropout d = (⊙) <$> mkMask d

-- | Generate a 0-1 mask with given probability, suitable for dropout,
-- or all ones if not in training phase. See 'noise' for the sampling
-- behaviour.
mkMask :: forall s t. KnownShape s => KnownFloat t => DropProb -> Gen (Tensor s t)
mkMask (DropProb dropProb) = do
  let keepProb = 1 - dropProb
  let isTraining = genTrainingPlaceholder
  r <- noise $ UniformD keepProb (1 + keepProb)
  return $ if_ isTraining
               (floor r ⊘ constant (knownAlgebraic @t $ realToFrac keepProb))
               ones
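
{- Standalone sketch (hypothetical helper on plain Doubles): the arithmetic
   of mkMask for a single element. A sample r drawn uniformly from
   [keepProb, 1 + keepProb) has floor r == 1 with probability keepProb and
   0 otherwise. Dividing by keepProb keeps the expected value at 1.

module Main where

maskValue :: Double -> Double -> Double
maskValue keepProb u =   -- u stands for a uniform sample from [0, 1)
  fromIntegral (floor (keepProb + u) :: Int) / keepProb

main :: IO ()
main = do
  let keepProb = 0.5
      grid = map (/ 100) [0 .. 99]        -- evenly spaced stand-ins for u
      masks = map (maskValue keepProb) grid
  print (map (maskValue keepProb) [0.1, 0.5, 0.9])  -- [0.0,2.0,2.0]
  print (sum masks / fromIntegral (length masks))    -- 1.0
-}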

newtype EndoTensor t s = EndoTensor (Tensor s t -> Tensor s t)

-- | Generate a dropout function for a heterogeneous tensor vector.
mkDropouts :: KnownFloat t => KnownLen shapes => All KnownShape shapes => DropProb -> Gen (HTV t shapes -> HTV t shapes)
mkDropouts d = appEndoTensor <$> mkDropouts' typeSList where
   mkDropouts' :: forall shapes t. KnownFloat t => All KnownShape shapes =>
                  SList shapes -> Gen (NP (EndoTensor t) shapes)
   mkDropouts' Unit = return Unit
   mkDropouts' (_ :* rest) = do
     x <- mkDropout d
     xs <- mkDropouts' rest
     return (EndoTensor x :* xs)

   appEndoTensor :: NP (EndoTensor t) s -> HTV t s -> HTV t s
   appEndoTensor Unit Unit = Unit
   appEndoTensor (EndoTensor f :* fs) (F x :* xs) = F (f x) :* appEndoTensor fs xs


------------------------
-- Convolutional layers

data ConvP t outChannels inChannels filterSpatialShape
  = ConvP (T (filterSpatialShape ++ '[inChannels,outChannels]) t)
          (T '[outChannels] t)

instance (KnownNat outChannels,KnownNat inChannels, KnownShape filterSpatialShape, KnownFloat t) =>
  ParamWithDefault (ConvP t outChannels inChannels filterSpatialShape) where
  defaultInitializer = prodHomo @filterSpatialShape @'[inChannels, outChannels] #>
                       prodAssoc @(Product filterSpatialShape) @inChannels @outChannels #>
                       knownAppend @filterSpatialShape @'[inChannels,outChannels] ?>
                       knownProduct @filterSpatialShape ?>
                       ConvP <$> (reshape <$> i) <*> pure (knownAlgebraic @t (constant 0.1))
    where i :: Gen (T '[Product filterSpatialShape*inChannels,outChannels] t)
          i = knownProduct @filterSpatialShape ?> glorotUniform

instance (KnownNat outChannels,KnownNat inChannels, KnownShape filterSpatialShape, KnownAlgebraic t) =>
  KnownTensors (ConvP t outChannels inChannels filterSpatialShape) where
  travTensor f s (ConvP x y) = knownAppend @filterSpatialShape @'[inChannels,outChannels] ?>
          (ConvP <$> travTensor f (s<>"_filters") x <*> travTensor f (s <> "_biases") y)

-- | Size-preserving convolution layer
conv' :: forall s outChannels filterSpatialShape inChannels t.
               KnownShape s => KnownNat inChannels => KnownNat outChannels => KnownShape filterSpatialShape => KnownAlgebraic t
            => Length filterSpatialShape <= 3
            => Length filterSpatialShape ~ Length s
            => ConvP t outChannels inChannels filterSpatialShape
            -> T (s ++ '[inChannels]) t
            -> T (s ++ '[outChannels]) t
conv' (ConvP filters bias) input = mapTT @s (+bias) (convolution @outChannels @filterSpatialShape @inChannels @s input filters)



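-- | Size-preserving convolution layer, where the output shape @s@ is
-- given directly; its last dimension is the number of output channels.
conv :: forall outChannels filterSpatialShape inChannels s t.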
               KnownShape s => KnownNat inChannels => KnownNat outChannels => KnownShape filterSpatialShape => KnownAlgebraic t
            => Length filterSpatialShape <= 3
            => (Length filterSpatialShape + 1) ~ Length s -- The ranks must match, but not necessarily the dimensions
            => (Last s ~ outChannels)
            => ConvP t outChannels inChannels filterSpatialShape
            -> T (Init s ++ '[inChannels]) t
            -> T s t
conv = initLast' @s #>
       incrPos @(Length filterSpatialShape) #>
       lengthInit (typeSList @s) #>
       incrCong @(Length filterSpatialShape) @(Length (Init s)) #>
       knownInit @s ?>
       conv' @(Init s)


-- -- | Convolution layers with no padding (applying the filter only on
-- -- positions where the input is fully defined, aka "VALID" in
-- -- tensorflow.)
-- convValid :: forall outChannels filterSpatialShape inChannels s t.
--                   ((1 + Length filterSpatialShape) ~ Length s,
--                    Length filterSpatialShape <= 3,
--                    KnownLen filterSpatialShape) -- the last dim of s is the batch size
--           => ConvP t outChannels inChannels filterSpatialShape -- ^ Parameters
--           -> T ('[inChannels] ++ AddSpatialDims s filterSpatialShape) ('Typ 'Float t) -- ^ input
--           -> (T ('[outChannels] ++ s) ('Typ 'Float t))
-- convValid (ConvP filters bias) input = convolutionValid input filters + bias

-- | Gated Linear Unit: the sigmoid of the first half of the input gates the second half.
-- See: Language Modeling with Gated Convolutional Networks
-- https://arxiv.org/pdf/1612.08083.pdf
glu :: forall n t. KnownBits t => KnownNat n => T '[n+n] ('Typ 'Float t) -> T '[n] ('Typ 'Float t)
glu x = plusMono @n @n #> knownPlus @n @n ?>
        let gate, h :: T '[n] ('Typ 'Float t)
            gate = slice0 @0 @n x
            h =  termCancelation @n @n #> slice0 @n @(n+n) x
        in sigmoid gate ⊙ h
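
{- Standalone sketch (hypothetical helper on plain lists): glu splits its
   input of length n+n into a gate (the first n entries) and a value (the
   last n entries), and returns sigmoid gate ⊙ value.

module Main where

gluL :: [Double] -> [Double]
gluL x = zipWith (*) (map sigmoid gate) h
  where
    (gate, h) = splitAt (length x `div` 2) x
    sigmoid z = 1 / (1 + exp (negate z))

main :: IO ()
main = print (gluL [0, 0, 2, 4])   -- [1.0,2.0], since sigmoid 0 = 0.5
-}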


================================================
FILE: TypedFlow/Layers/RNN/Attention.hs
================================================
{-|
Module      : TypedFlow.Layers.RNN.Attention
Description : Attention combinators to be used with RNN cells
Copyright   : (c) Jean-Philippe Bernardy, 2018
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental
-}

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeInType #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE PatternSynonyms #-}

module TypedFlow.Layers.RNN.Attention (
  -- * Attention mechanisms
  -- ** Scoring functions
  AttentionScoring,
  multiplicativeScoring,
  AdditiveScoringP(..), additiveScoring,
  -- ** Attention functions
  AttentionFunction,
  uniformAttn,
  luongAttention,
  -- ** Attention combinators
  attentiveWithFeedback
  ) where

import Prelude hiding (RealFrac(..))
import GHC.TypeLits
import TypedFlow.TF
import TypedFlow.Types
import TypedFlow.Types.Proofs (appRUnit,(#>))
import TypedFlow.Layers.RNN.Base

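-- | An attention scoring function. This function should produce an
-- unnormalised score for a key and a value; the attention function
-- (e.g. 'uniformAttn') turns the scores into weights with a softmax.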
type AttentionScoring t keySize valueSize = 
  Tensor '[keySize] t -> Tensor '[valueSize] t -> Tensor '[] t

-- | A function which attends to an external input. Typically a
-- function of this type is a closure which has the attended input in
-- its environment. This environment is interpreted as an associative
-- memory from key to value.
type AttentionFunction t keySize valueSize =
  T '[keySize] t -> T '[valueSize] t

-- | @uniformAttn score len hs key@ scores each row of @hs@ against
-- @key@, zeroes the scores of the rows at or beyond position @len@,
-- normalises the scores with a softmax, and returns the correspondingly
-- weighted sum of the rows of @hs@.
uniformAttn :: ∀ valueSize m keySize t. KnownNat valueSize => KnownNat m => KnownFloat t
       => AttentionScoring t keySize valueSize -- ^ scoring function
       -> T '[] Int32 -- ^ length of the input
       -> T '[m,valueSize] t -- ^ input (what we're attending to)
       -> AttentionFunction t keySize valueSize
uniformAttn score len hs key = c
  where xx,α :: T '[m] t
        xx = mapT (score key) hs
        α = softmax0 (mask ⊙ xx)
        c :: T '[valueSize] t
        c = hs ∙ α
        mask = cast (sequenceMask @m len) -- mask according to length
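
{- Standalone sketch (hypothetical helpers on plain lists): the computation
   of uniformAttn. Each row of hs is scored against the key. The scores of
   positions at or beyond len are multiplied by 0, as mask ⊙ xx does. A
   softmax turns the scores into weights, and the result is the weighted
   sum of the rows.

module Main where

softmax :: [Double] -> [Double]
softmax xs = map (/ total) es
  where
    m = maximum xs
    es = map (\x -> exp (x - m)) xs
    total = sum es

uniformAttnL :: ([Double] -> [Double] -> Double) -> Int -> [[Double]] -> [Double] -> [Double]
uniformAttnL score len hs key = foldr1 (zipWith (+)) weighted
  where
    scores = map (score key) hs
    mask = [if i < len then 1 else 0 | i <- [0 .. length hs - 1]]
    alphas = softmax (zipWith (*) mask scores)
    weighted = zipWith (\a h -> map (a *) h) alphas hs

dot :: [Double] -> [Double] -> Double
dot u v = sum (zipWith (*) u v)

main :: IO ()
main = print (uniformAttnL dot 2 [[1, 0], [0, 1]] [0, 0])  -- [0.5,0.5]
-}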

-- | Add some attention to an RnnCell, and feed the attention vector to
-- the next iteration in the rnn. (This follows the diagram at
-- https://github.com/tensorflow/nmt#background-on-the-attention-mechanism
-- commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a).
attentiveWithFeedback ::forall attSize cellSize inputSize w ss. KnownNat inputSize => KnownNat attSize => KnownLen ss =>
  KnownTyp w =>
  AttentionFunction w cellSize attSize ->
  RnnCell w ss                   (T '[inputSize+attSize] w) (T '[cellSize] w) ->
  RnnCell w ('[attSize] ': ss)   (T '[inputSize        ] w) (T '[attSize] w)
attentiveWithFeedback attn cell = appRUnit @ss #> withFeedback (cell .-. timeDistribute attn)


-- -- | LSTM for an attention model. The result of attention is fed to the next step.
-- attentiveLstm :: forall attSize n x bs t. KnownNat bs =>
--   AttentionFunction t bs n attSize ->
--   LSTMP t n (x+attSize) ->
--   RnnCell t '[ '[attSize,bs], '[n,bs], '[n,bs] ] (Tensor '[x,bs] (Flt t)) (Tensor '[attSize,bs] (Flt t))
-- attentiveLstm att w = attentiveWithFeedback att (lstm w)


-- | Luong attention function (following
-- https://github.com/tensorflow/nmt#background-on-the-attention-mechanism
-- commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a).
-- Essentially a dense layer with tanh activation, on top of uniform attention.
luongAttention :: ∀ attnSize d m e w. KnownNat e => KnownNat d => KnownNat attnSize => KnownNat m => KnownFloat w
  => Tensor '[d+e,attnSize] w     -- ^ weights for the dense layer
  -> AttentionScoring w e d -- ^ scoring function
  -> Tensor '[] Int32          -- ^ length of the input
  -> T '[m,d] w         -- ^ inputs
  -> AttentionFunction w e attnSize
luongAttention w scoring lens hs_ ht = 
  let ct = uniformAttn scoring lens hs_ ht
  in (tanh (w ∙ (concat0 ct ht)))
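{- Sketch (illustrative only): the final step of luongAttention on
lists, i.e. a tanh dense layer (without bias) applied to the context
vector concatenated with the current decoder state. The weight layout
(one row per output unit) is an assumption of the sketch and may differ
from TypedFlow's; luongOutput and matVec are hypothetical names.

```haskell
type Vec = [Double]
type Mat = [Vec]   -- one row per output unit (layout assumed by this sketch)

matVec :: Mat -> Vec -> Vec
matVec rows x = [sum (zipWith (*) r x) | r <- rows]

-- tanh dense layer on the concatenation [context ; state].
luongOutput :: Mat -> Vec -> Vec -> Vec
luongOutput w context state = map tanh (matVec w (context ++ state))
```
-}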

-- | Multiplicative (bilinear) scoring function: the score of a value
-- @h@ for a key @dt@ is @h · (w ∙ dt)@.
multiplicativeScoring :: forall valueSize keySize t.
  KnownFloat t => KnownNat valueSize => KnownNat keySize
  => T [keySize,valueSize] t -- ^ weights
  ->  AttentionScoring t keySize valueSize
multiplicativeScoring w dt h = ir · h
  where ir :: T '[valueSize] t
        ir = w ∙ dt
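{- Sketch (illustrative only): the multiplicative score
value · (W key) on lists. The layout of W (one row per value dimension,
each of length keySize) and the names are assumptions of this sketch.

```haskell
type Vec = [Double]
type Mat = [Vec]   -- valueSize rows, each of length keySize (assumed layout)

dot :: Vec -> Vec -> Double
dot u v = sum (zipWith (*) u v)

-- Bilinear score: project the key with W, then dot with the value.
multiplicativeScore :: Mat -> Vec -> Vec -> Double
multiplicativeScore w key value = dot (map (`dot` key) w) value
```

With W the identity matrix the score reduces to the plain dot product
of key and value.
-}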


data AdditiveScoringP sz keySize valueSize t = AdditiveScoringP
  (Tensor '[1,sz]          t)
  (Tensor '[keySize, sz]   t)
  (Tensor '[valueSize, sz] t)

instance (KnownNat sz, KnownNat keySize, KnownNat valueSize, KnownTyp t) => KnownTensors (AdditiveScoringP sz keySize valueSize t) where
  travTensor f s (AdditiveScoringP x y z) = AdditiveScoringP <$> travTensor f (s<>"_v") x <*> travTensor f (s<>"_w1") y <*> travTensor f (s<>"_w2") z
instance (KnownNat sz, KnownNat keySize, KnownNat valueSize, KnownFloat t) => ParamWithDefault (AdditiveScoringP sz keySize valueSize t) where
  defaultInitializer = AdditiveScoringP <$> glorotUniform <*> glorotUniform <*> glorotUniform

-- | An additive scoring function, computing @v · tanh (w1 ∙ h + w2 ∙ dt)@.
-- See https://arxiv.org/pdf/1412.7449.pdf
additiveScoring :: forall sz keySize valueSize t. KnownNat valueSize => KnownNat sz => KnownNat keySize => KnownFloat t =>
  AdditiveScoringP sz keySize valueSize t -> AttentionScoring t valueSize keySize
additiveScoring (AdditiveScoringP v w1 w2) dt h =  r''
  where w1h :: Tensor '[sz] t
        w1h = w1 ∙ h
        w2dt = w2 ∙ dt
        z' :: Tensor '[sz] t
        z' = tanh (w1h + w2dt)
        r'' = z' · squeeze0 v
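{- Sketch (illustrative only): the additive score
v · tanh (W1 value + W2 query) on lists, where value plays the role of
h and query the role of dt in additiveScoring. The row-per-output
weight layout and the names are assumptions of this sketch.

```haskell
type Vec = [Double]
type Mat = [Vec]   -- sz rows (assumed layout)

dot :: Vec -> Vec -> Double
dot u v = sum (zipWith (*) u v)

matVec :: Mat -> Vec -> Vec
matVec m x = map (`dot` x) m

-- v . tanh (W1 value + W2 query)
additiveScore :: Vec -> Mat -> Mat -> Vec -> Vec -> Double
additiveScore v w1 w2 query value =
  dot v (map tanh (zipWith (+) (matVec w1 value) (matVec w2 query)))
```
-}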



================================================
FILE: TypedFlow/Layers/RNN/Base.hs
================================================
{-|
Module      : TypedFlow.Layers.RNN.Base
Description : RNN cells, layers and combinators.
Copyright   : (c) Jean-Philippe Bernardy, 2017
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental
-}

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeInType #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE PatternSynonyms #-}

module TypedFlow.Layers.RNN.Base (
  -- * Cell Combinators
  RnnCell,
  simpleRnn,
  runCell, mkCell,
  stackRnnCells, (.-.),
  bothRnnCells, (.|.),
  withBypass, withFeedback,
  onStates,
  -- * Rnn Combinators
  Rnn,
  runRnn,
  stackRnns, (.--.),
  bothRnns,(.++.),
  -- * RNN unfolding functions
  timeDistribute,
  iterateCell,
  iterateCellBackward,
  iterateWithCull,
  -- * Monad-like interface for cell construction
  Component(..), bindC, returnC,
  -- rnnBackwardsWithCull,
  )

where
import Prelude hiding (tanh,Num(..),Floating(..),floor)
import GHC.TypeLits
import TypedFlow.TF
import TypedFlow.Types
import TypedFlow.Types.Proofs
-- import Data.Type.Equality
-- import Data.Kind (Type,Constraint)

-- | The RNN Component generalized monad. This can be used to build
-- RNN cells which do not follow the simple and usual "stacking"
-- pattern. This is not a simple monad, because the indexing over
-- states is non-uniform; see 'bindC'.
newtype Component t (states::[Shape]) a
  = C {runC :: HTV t states -> (HTV t states , a)}
-- Note: states are tensors only, because we need to index into them
-- in the time dimension in iterateWithCull

instance Functor (Component t states) where
  fmap = mapC

mapC :: (a -> b) -> Component t s a -> Component t s b
mapC f c = C $ \s ->
  let (s',x) = runC c s
  in (s', f x)

-- | Unit of the Component monad.
returnC :: a -> Component t '[] a
returnC x = C $ \Unit -> (Unit,x)

-- | Bind operation for Components. States are accumulated.
bindC :: forall t s0 s1 a b. KnownLen s1
  => Component t s0 a -> (a -> Component t s1 b) -> Component t (s1++s0) b
bindC f g = C $ \(hsplit @s1 -> (s1,s0)) -> 
  let (s0',x) = runC f s0
      (s1',y) = runC (g x) s1
  in (happ s1' s0',y)
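{- Sketch (illustrative only): a dynamically checked model of Component
and bindC. The type-level list of state shapes is replaced by a plain
list of state components, and the split point that TypedFlow obtains
from KnownLen s1 is passed explicitly as n1. splitAt plays the role of
hsplit and (++) the role of happ; all names here are hypothetical.

```haskell
-- Stand-in for HTV: one list of numbers per state component.
type State = [[Double]]

newtype Component a = C { runC :: State -> (State, a) }

-- Unit: a component that owns no state (expects the empty state).
returnC :: a -> Component a
returnC x = C (\[] -> ([], x))

-- Sequence two components; the continuation's n1 states come first.
bindC :: Int -> Component a -> (a -> Component b) -> Component b
bindC n1 f g = C $ \s ->
  let (s1, s0) = splitAt n1 s     -- like hsplit @s1
      (s0', x) = runC f s0        -- run the first component on its own states
      (s1', y) = runC (g x) s1    -- then the continuation on the others
  in (s1' ++ s0', y)              -- like happ s1' s0'

-- A cell with one state component: adds its input to a running total.
accumulate :: Double -> Component Double
accumulate x = C $ \[[acc]] -> let acc' = acc + x in ([[acc']], acc')
```

Stacking two accumulators with bindC 1 on the combined state [[10],[1]]
runs the first cell on [1] and the continuation on [10], giving the new
state [[13],[3]] and output 13: the continuation's states come first,
as in s1 ++ s0.
-}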

-- | A cell (one time step) in an RNN. @states@ lists the shapes of the states propagated through time.
type RnnCell t states input output = input -> Component t states output

-- | An RNN. @n@ is the length of the time sequence. @state@ is the state propagated through time.
type Rnn n b state input output = RnnCell b state (V n input) (V n output)

-- | Run a cell
runCell :: RnnCell t states input output -> (HTV t states,input) -> (HTV t states, output)
runCell cell = uncurry (flip (runC . cell))

-- | Run an RNN, using a tensor as input. @n@ is the length of the time sequence. 
runRnn :: (KnownNat n,KnownShape s0, KnownShape s1, KnownTyp t1)
       => Rnn n t2 states (T s1 t1) (T s0 t0)
       -> (HTV t2 states, Tensor (n ': s1) t1)
       -> (HTV t2 states, Tensor (n ': s0) t0)
runRnn l (s,x) =
  let x' = unstack0 x
      (s',y) = runCell l (s,x')
  in (s',stack0 y)

-- | Run an RNN composed of a single RNN cell.
simpleRnn :: KnownTyp t1 => KnownShape s1 => KnownShape s0 => KnownNat n
          => RnnCell t2 states (T s1 t1) (T s0 t0)
          -> (HTV t2 states, Tensor (n : s1) t1)
          -> (HTV t2 states, Tensor (n : s0) t0)
simpleRnn = runRnn . iterateCell

-- | Construct a cell from an arbitrary stateful function
mkCell :: ((HTV t states,input) -> (HTV t states, output)) -> RnnCell t states input output
mkCell cell = C . flip (curry cell)

----------------------
-- Lifting functions

-- | Convert a pure function (feed-forward layer) to an RNN cell by
-- ignoring the RNN state.
timeDistribute :: (a -> b) -> RnnCell t '[] a b
timeDistribute = constantOverSteps

-- | Convert a pure function (feed-forward layer) to an RNN cell by
-- ignoring the RNN state.
constantOverSteps :: (a -> b) -> RnnCell t '[] a b
constantOverSteps stateLess a = returnC (stateLess a)

--------------------------------------
-- Combinators

-- | Stack two RNN layers (the LHS is run first); the resulting state
-- combines both layers' states. This is useful, for example, to
-- combine forward and backward layers.
(.--.),stackRnns :: forall s1 s2 a b c n bits. KnownLen s2
  => Rnn n bits s1 a b -> Rnn n bits s2 b c -> Rnn n bits (s2 ++ s1) a c
stackRnns = stackRnnCells

infixr .--.
(.--.) = stackRnns

-- | Run two RNN layers in parallel on the same input, concatenating
-- their outputs at each time step.
bothRnns,(.++.)  :: forall s1 s2 a b c n bits t.
  KnownTyp t => KnownLen s1 => KnownLen s2 => KnownNat n
  => KnownNat b => KnownNat c
  => Rnn n bits s1 a (T '[b] t) -> Rnn n bits s2 a (T '[c] t) -> Rnn n bits (s2 ++ s1) a (T ('[b+c]) t)
bothRnns f g x =
  f x `bindC` \y ->
  g x `bindC` \z ->
  returnC (concat0 <$> y <*> z)

infixr .++.
(.++.) = bothRnns

-- | Apply a function on the cell state(s) before running the cell itself.
onStates ::  (HTV t xs -> HTV t xs) -> RnnCell t xs a b -> RnnCell t xs a b
onStates f cell x = C $ \h -> runC (cell x) (f h)

-- | Stack two RNN cells (LHS is run first)
stackRnnCells, (.-.) :: forall s0 s1 a b c t. KnownLen s1
  => RnnCell t s0 a b -> RnnCell t s1 b c -> RnnCell t (s1 ++ s0) a c
stackRnnCells l1 l2 x = l1 x `bindC` l2
(.-.) = stackRnnCells


-- | Run two RNN cells in parallel on the same input, concatenating
-- their outputs.
bothRnnCells, (.|.) :: forall s0 s1 a b c t bits. KnownLen s0 => KnownLen s1
  => KnownBits bits
  => KnownNat b => KnownNat c
  => RnnCell t s0 a (T '[b] (Flt bits))
  -> RnnCell t s1 a (T '[c] (Flt bits))
  -> RnnCell t (s1 ++ s0) a (T '[b+c] (Flt bits))
bothRnnCells l1 l2 x  =
  l1 x `bindC` \y ->
  l2 x `bindC` \z ->
  returnC (concat0 y z)

(.|.) = bothRnnCells


-- | Run the cell, and forward the input to the output, by
-- concatenation with the output of the cell. This bypass is sometimes
-- called a 'highway' in the literature.
withBypass :: forall x y t b s0. KnownNat x => KnownNat y => KnownLen s0
  => KnownTyp t
  => RnnCell b s0 (T '[x] t) (T '[y] t) -> RnnCell b s0 (T '[x] t) (T '[x+y] t)
withBypass cell x = appRUnit @s0 #>
  cell x `bindC` \y ->
  returnC (concat0 x y)

-- | Run the cell, and feed its output as input to the next time step
-- (concatenated after the regular input). The previous output is kept
-- as an additional state.
withFeedback :: forall outputSize inputSize (w :: Typ) ss.
  KnownTyp w => KnownNat outputSize => KnownNat inputSize =>
  RnnCell w ss                    (T '[inputSize+outputSize] w) (T '[outputSize] w) ->
  RnnCell w ('[outputSize] ': ss) (T '[inputSize           ] w) (T '[outputSize] w)
withFeedback cell x = C $ \(F prevOutput :* s) ->
  let (s',y) = runC (cell (concat0 x prevOutput)) s
  in  (F y :* s',y)
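{- Sketch (illustrative only): withFeedback with plain state-passing
functions instead of Component. The previous output becomes an extra
state component in front of the cell's own state, and is concatenated
after the input at each step. Cell, runSeq and sumCell are hypothetical
names.

```haskell
type Vec = [Double]

-- A cell as a state-passing function (stand-in for RnnCell).
type Cell s i o = s -> i -> (s, o)

-- Feed the previous output back in, after the regular input.
withFeedback :: Cell s Vec Vec -> Cell (Vec, s) Vec Vec
withFeedback cell (prevOut, s) x =
  let (s', y) = cell s (x ++ prevOut)
  in ((y, s'), y)

-- Run a cell over a sequence, threading the state left to right.
runSeq :: Cell s i o -> s -> [i] -> (s, [o])
runSeq cell = go
  where go s []       = (s, [])
        go s (x : xs) = let (s1, y)  = cell s x
                            (s2, ys) = go s1 xs
                        in (s2, y : ys)

-- A stateless example cell: outputs the sum of its input vector.
sumCell :: Cell () Vec Vec
sumCell () v = ((), [sum v])
```

Running withFeedback sumCell on [[1],[2],[3]] from an initial output
[0] produces [[1],[3],[6]], since each step adds the previous output to
the new input.
-}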

---------------------------------------------------------
-- RNN unfolding

-- | Build a RNN by repeating a cell @n@ times.
iterateCell :: ∀ n state input output b.
       (KnownNat n) =>
       RnnCell b state input output -> Rnn n b state input output
iterateCell c x = C $ \s -> chainForward (\(t,y) -> runC (c y) t) (s,x)

-- | Build a RNN by repeating a cell @n@ times. However the state is
-- propagated in the right-to-left direction (decreasing indices in
-- the time dimension of the input and output tensors)
iterateCellBackward :: ∀ n state input output b.
       (KnownNat n) =>
       RnnCell b state input output -> Rnn n b state input output
iterateCellBackward c x = C $ \s -> chainBackward (\(t,y) -> runC (c y) t) (s,x)

-- | RNN helper
chainForward :: ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (state , V n b)
chainForward _ (s0 , VUnit) = (s0 , VUnit)
chainForward f (s0 , x :** xs) = 
  let (s1,x') = f (s0 , x)
      (sFin,xs') = chainForward f (s1 , xs)
  in  (sFin,(x':**xs'))

-- | RNN helper
chainBackward :: ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (state , V n b)
chainBackward _ (s0 , VUnit) = (s0 , VUnit)
chainBackward f (s0 , (x:**xs)) =
  let (s1,xs') = chainBackward f (s0,xs)
      (sFin, x') = f (s1,x)
  in (sFin,(x':**xs'))
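{- Sketch (illustrative only): on ordinary lists, chainForward and
chainBackward coincide with mapAccumL and mapAccumR from Data.List (the
length index of V is simply dropped). In both cases the outputs keep
the input order; only the direction in which the state flows differs.
The names chainForwardL and chainBackwardL are hypothetical.

```haskell
import Data.List (mapAccumL, mapAccumR)

-- State flows left to right, as in chainForward.
chainForwardL :: ((s, a) -> (s, b)) -> (s, [a]) -> (s, [b])
chainForwardL f (s0, xs) = mapAccumL (curry f) s0 xs

-- State flows right to left, as in chainBackward.
chainBackwardL :: ((s, a) -> (s, b)) -> (s, [a]) -> (s, [b])
chainBackwardL f (s0, xs) = mapAccumR (curry f) s0 xs
```

With step (s, x) = (s + x, s), forward threading on [1,2,3] from 0
gives (6, [0,1,3]); backward threading gives (6, [5,3,0]).
-}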


-- | RNN helper
chainForwardWithState :: ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (V n b, V n state)
chainForwardWithState _ (_s0 , VUnit) = (VUnit, VUnit)
chainForwardWithState f (s0 , (x:**xs)) =
  let (s1,x') = f (s0 , x)
      (xs',ss) = chainForwardWithState f (s1 , xs)
  in ((x':**xs'), (s1:**ss) )

-- -- | RNN helper
-- chainBackwardWithState ::
--   ∀ state a b n. ((state , a) -> (state , b)) → (state , V n a) -> (state , V n b, V n state)
-- chainBackwardWithState _ (s0 , VUnit) = return (s0 , VUnit, VUnit)
-- chainBackwardWithState f (s0 , (x:**xs)) = do
--   (s1,xs',ss') <- chainBackwardWithState f (s0,xs)
--   (sFin, x') <- f (s1,x)
--   return (sFin,(x':**xs'),(sFin:**ss'))

-- | RNN helper
transposeV :: forall n xs t. All KnownShape xs => KnownNat n =>
               SList xs -> V n (HTV t xs) -> HTV t (Ap (FMap (Cons n)) xs)
transposeV Unit _ = Unit
transposeV (_ :* n) xxs  = F ys' :* yys'
  where (ys,yys) = help @(Tail xs) xxs
        ys' = stack0 ys
        yys' = transposeV n yys
        help :: forall ys x tt. V n (HTV tt (x ': ys)) -> (V n (T x tt) , V n (HTV tt ys))
        help (xs) = ((fmap (fromF . hhead) xs),(fmap htail xs))

-- | @gatherFinalStates dynLen states = states[dynLen-1]@, i.e. the
-- state reached after @dynLen@ time steps.
gatherFinalStates :: KnownShape x => KnownNat n => T '[] Int32 -> T (n ': x) t -> T x t
gatherFinalStates dynLen states = gather states (dynLen ⊝ constant 1)

gathers :: forall n xs t. All KnownShape xs => KnownNat n =>
            SList xs -> T '[] Int32 -> HTV t (Ap (FMap (Cons n)) xs) -> HTV t xs
gathers Unit _ Unit = Unit
gathers (_ :* n) ixs (F x :* xs) = F (gatherFinalStates ixs x) :* gathers @n n ixs xs

-- | @iterateWithCull dynLen cell@ runs the cell over all @n@ time steps
-- as usual, but returns the state reached after step @dynLen@ only.
iterateWithCull :: forall n x y ls b.
  KnownLen ls => KnownNat n => All KnownShape ls =>
  T '[] Int32 -- ^ dynamic length
  -> RnnCell b ls x y -> Rnn n b ls x y
iterateWithCull dynLen cell xs = C $ \s0 ->
  let (us,ss) = chainForwardWithState (uncurry (flip (runC . cell))) (s0,xs)
      sss = transposeV @n (typeSList @ls) ss
  in (gathers @n (typeSList @ls) dynLen sss,us)
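{- Sketch (illustrative only): the semantics of iterateWithCull on
lists. Every time step is still run and every output is returned, but
the returned state is the one reached after dynLen steps, which is
element dynLen - 1 of the list of per-step states (as in
gatherFinalStates). Assumes 1 <= dynLen <= length xs; iterateWithCullL
is a hypothetical name.

```haskell
iterateWithCullL :: Int -> ((s, a) -> (s, b)) -> (s, [a]) -> (s, [b])
iterateWithCullL dynLen f (s0, xs) = (states !! (dynLen - 1), outputs)
  where
    -- results !! i = (state after step i+1, output of step i+1)
    results       = go s0 xs
    go _ []       = []
    go s (x : ys) = let (s', y) = f (s, x) in (s', y) : go s' ys
    states        = map fst results
    outputs       = map snd results
```
-}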

-- -- | Like @rnnWithCull@, but states are threaded backwards.
-- rnnBackwardsWithCull :: forall n bs x y ls b.
--   KnownLen ls => KnownNat n => All KnownLen ls => All (LastEqual bs) ls =>
--   T '[bs] Int32 -> RnnCell b ls x y -> RNN n b ls x y
-- rnnBackwardsWithCull dynLen cell (s0, t) = do
--   (us,ss) <- chainBackwardWithState cell (s0,xs)
--   let sss = transposeV @n (shapeSList @ls) ss
--   return (gathers @n (shapeSList @ls) (n - dynLen) sss,us)


================================================
FILE: TypedFlow/Layers/RNN/Cells.hs
================================================
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnicodeSyntax #-}
{-|
Module      : TypedFlow.Layers.RNN.Cells
Description : RNN cells
Copyright   : (c) Jean-Philippe Bernardy, 2017
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental
-}


module TypedFlow.Layers.RNN.Cells (
  -- * RNN Cells
  cellInitializerBit,
  LSTMP(..),
  lstm,
  GRUP(..),
  gru,
  StackP(..),
  stackRU,
  ) where

import TypedFlow.Layers.RNN.Base
import TypedFlow.TF
import TypedFlow.Types
import TypedFlow.Types.Proofs
import GHC.TypeLits
import TypedFlow.Layers.Core (DenseP(..),(#))
import Prelude hiding (RealFrac(..))

--------------------------------------
-- Cells

-- | Standard RNN gate initializer. The recurrent kernel is orthogonal
-- (to avoid divergence), the input kernel is Glorot-uniform, and the
-- bias is zero.
cellInitializerBit :: ∀ n x t. (KnownNat n, KnownNat x, KnownFloat t) => Gen (DenseP t (n + x) n)
cellInitializerBit = DenseP <$> (concat0 <$> recurrentInitializer <*> kernelInitializer) <*> biasInitializer
  where recurrentInitializer :: Gen (Tensor '[n, n] t)
        recurrentInitializer = noise $ OrthogonalD
        kernelInitializer :: Gen (Tensor '[x, n] t)
        kernelInitializer = glorotUniform
        biasInitializer = pure zeros

-- | Parameters for an LSTM: dense layers for the forget, input, cell
-- and output gates (in that order).
data LSTMP t n x = LSTMP (DenseP t (n+x) n) (DenseP t (n+x) n) (DenseP t (n+x) n) (DenseP t (n+x) n)

instance (KnownNat n, KnownNat x, KnownFloat t) => KnownTensors (LSTMP t n x) where
  travTensor f s (LSTMP x y z w) = LSTMP <$> travTensor f (s<>"_f") x <*> travTensor f (s<>"_i") y <*> travTensor f (s<>"_c") z <*> travTensor f (s<>"_o") w
instance (KnownNat n, KnownNat x, KnownFloat t) => ParamWithDefault (LSTMP t n x) where
  defaultInitializer = LSTMP <$> forgetInit <*> cellInitializerBit <*> cellInitializerBit <*> cellInitializerBit
    where forgetInit = DenseP <$> (denseWeights <$> cellInitializerBit) <*> pure ones

-- | Standard LSTM cell. Its state is the pair (previous output, cell state).
lstm :: ∀ n x t. KnownNat x => KnownNat n => KnownFloat t
  => LSTMP t n x -> RnnCell t '[ '[n], '[n]] (Tensor '[x] t) (Tensor '[n] t)
lstm (LSTMP wf wi wc wo) input = C $ \(VecPair ht1 ct1) -> 
  let f = sigmoid (wf # hx)
      hx = (concat0 ht1 input)
      i = sigmoid (wi # hx)
      cTilda = tanh (wc # hx)
      o = sigmoid (wo # hx)
      c = ((f ⊙ ct1) + (i ⊙ cTilda))
      h = (o ⊙ tanh c)
  in (VecPair h c, h)
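{- Sketch (illustrative only): one step of the lstm cell above on
lists, assuming that applying a dense layer computes W x + b with one
weight row per output unit. Dense, apply and lstmStep are hypothetical
names; the gate equations follow lstm line by line.

```haskell
type Vec = [Double]

-- Dense layer: one weight row per output unit, plus a bias (assumed layout).
data Dense = Dense [Vec] Vec

apply :: Dense -> Vec -> Vec
apply (Dense rows bias) x = zipWith (+) bias [sum (zipWith (*) r x) | r <- rows]

sigmoid :: Double -> Double
sigmoid z = 1 / (1 + exp (negate z))

-- One LSTM step. The state is (previous output h, cell state c), as in lstm.
lstmStep :: (Dense, Dense, Dense, Dense) -> (Vec, Vec) -> Vec -> ((Vec, Vec), Vec)
lstmStep (wf, wi, wc, wo) (hPrev, cPrev) x = ((h, c), h)
  where hx     = hPrev ++ x                   -- previous output first, then input
        f      = map sigmoid (apply wf hx)    -- forget gate
        i      = map sigmoid (apply wi hx)    -- input gate
        cTilda = map tanh (apply wc hx)       -- candidate cell state
        o      = map sigmoid (apply wo hx)    -- output gate
        c      = zipWith (+) (zipWith (*) f cPrev) (zipWith (*) i cTilda)
        h      = zipWith (*) o (map tanh c)
```

With all weights and biases zero every gate is 0.5 and the candidate is
0, so a previous cell state of 2 decays to 1 and the output is
0.5 * tanh 1.
-}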

-- | Parameters for a GRU: weight matrices (without biases) for the
-- update gate, the reset gate and the candidate state.
data GRUP t n x = GRUP (T [n+x,n] t) (T [n+x,n] t) (T [n+x,n] t)

instance (KnownNat n, KnownNat x, KnownFloat t) => KnownTensors (GRUP t n x) where
  travTensor f s (GRUP x y z) = GRUP <$> travTensor f (s<>"_z") x <*> travTensor f (s<>"_r") y <*> travTensor f (s<>"_w") z
instance (KnownNat n, KnownNat x, KnownFloat t) => ParamWithDefault (GRUP t n x) where
  defaultInitializer = GRUP <$> (denseWeights <$> cellInitializerBit) <*> (denseWeights <$> cellInitializerBit) <*> (denseWeights <$> cellInitializerBit)



-- | Standard GRU cell
gru :: ∀ n x t. KnownNat x => (KnownNat n, KnownFloat t) => GRUP t n x ->
        RnnCell t '[ '[n] ] (Tensor '[x] t) (Tensor '[n] t)
gru (GRUP wz wr w) xt = C $ \(VecSing ht1) ->
  let hx =  (concat0 ht1 xt)
      zt = sigmoid (wz ∙ hx)
      rt = sigmoid (wr ∙ hx)
      hTilda = tanh (w ∙ (concat0 (rt ⊙ ht1) xt))
      ht = ((ones ⊝ zt) ⊙ ht1 + zt ⊙ hTilda)
  in (VecSing ht, ht)
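{- Sketch (illustrative only): one step of gru on lists. As in GRUP
there are no biases; the row-per-output weight layout and the names
gruStep and matVec are assumptions of the sketch.

```haskell
type Vec = [Double]
type Mat = [Vec]   -- one row per output unit (assumed layout)

matVec :: Mat -> Vec -> Vec
matVec rows x = [sum (zipWith (*) r x) | r <- rows]

sigmoid :: Double -> Double
sigmoid z = 1 / (1 + exp (negate z))

-- One GRU step; returns (new state, output), which are equal.
gruStep :: (Mat, Mat, Mat) -> Vec -> Vec -> (Vec, Vec)
gruStep (wz, wr, w) hPrev x = (h, h)
  where hx     = hPrev ++ x
        z      = map sigmoid (matVec wz hx)                         -- update gate
        r      = map sigmoid (matVec wr hx)                         -- reset gate
        hTilda = map tanh (matVec w (zipWith (*) r hPrev ++ x))     -- candidate
        h      = zipWith3 (\zi hp ht -> (1 - zi) * hp + zi * ht) z hPrev hTilda
```
-}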


data StackP w n = StackP (DenseP w (n + n) 3)

defStackP :: KnownNat n => KnownFloat w => Gen (StackP w n)
defStackP = StackP <$> defaultInitializer
  -- (DenseP glorotUniform (stack0 (V [zeros, constant (-1), zeros]) )) -- demote popping a bit 

instance (KnownNat n, KnownTyp w) => KnownTensors (StackP w n) where
  travTensor f s (StackP d) = StackP <$> travTensor f s d

instance (KnownNat n, KnownFloat w) => (ParamWithDefault (StackP w n)) where
  defaultInitializer = defStackP

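-- | A stack recurrent unit. The input has two purposes: 1. it is
-- (softly) pushed onto the stack; 2. together with the top of the
-- stack, it is fed to a dense layer whose softmax output weights the
-- three possible operations on the stack (keep, pop, push). The first
-- type argument, @k@, is one less than the depth of the stack (which
-- holds @k+1@ elements).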
stackRU :: ∀k n bs w. KnownNat k => KnownNat n => (KnownNat bs) => (KnownFloat w) => StackP w n ->
        RnnCell w '[ '[k+1,n]] (Tensor '[n] w) (Tensor '[n] w)
stackRU (StackP w) input = C $ \(VecSing st1) ->
  succPos @k #>
  plusMono @k @1 #>
  plusComm @k @1 #>
  termCancelation @k @1 #>
  let ct1 = nth0' @0 st1
      hx = concat0 ct1 input
      action :: T '[3] w
      action = softmax0 (w # hx)
      tl :: T '[k,n] w
      tl = slice0 @1 @(k+1) st1
      it :: T '[k,n] w
      it = slice0 @0 @k  st1
      stTilda :: T '[3,k+1,n] w
      stTilda = stack0 (st1 :**  (tl `concat0` zeros) :** (expandDim0 input `concat0` it) :** VUnit)
      st :: T '[k+1,n] w
      st = inflate2 (flatten12 stTilda ∙ action)
      ct = nth0' @0 st
  in (VecSing st, ct)
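{- Sketch (illustrative only): one step of stackRU's soft stack on
lists. The stack is a list of rows, top first. The three candidates are
the unchanged stack, the popped stack (with a zero row appended at the
bottom) and the stack with the input pushed (the bottom row drops off);
the new stack is their weighted sum under a softmax of three action
scores. To keep the sketch small, the scores are passed in directly
instead of being computed by the dense layer on (top ++ input);
softStackStep is a hypothetical name.

```haskell
type Vec = [Double]

softmax :: Vec -> Vec
softmax xs = map (/ total) es
  where m     = maximum xs
        es    = map (\x -> exp (x - m)) xs
        total = sum es

-- actionScores = [keep, pop, push]; stack is non-empty, top first.
softStackStep :: Vec -> [Vec] -> Vec -> ([Vec], Vec)
softStackStep actionScores stack input = (stack', head stack')
  where keep         = stack
        pop          = tail stack ++ [map (const 0) input]   -- zero row at the bottom
        push         = input : init stack                    -- bottom row drops off
        [a0, a1, a2] = softmax actionScores
        mix p q r    = a0 * p + a1 * q + a2 * r
        stack'       = zipWith3 (zipWith3 mix) keep pop push
```

Scores such as [0,0,1000] make the softmax effectively one-hot, so the
step behaves as a hard push; intermediate scores blend the three
candidates.
-}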



================================================
FILE: TypedFlow/Layers/RNN.hs
================================================
{-|
Module      : TypedFlow.Layers.RNN
Description : RNN cells, layers and combinators.
Copyright   : (c) Jean-Philippe Bernardy, 2017
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental
-}


module TypedFlow.Layers.RNN (
  module TypedFlow.Layers.RNN.Base,
  module TypedFlow.Layers.RNN.Cells,
  module TypedFlow.Layers.RNN.Attention)  where

import TypedFlow.Layers.RNN.Base
import TypedFlow.Layers.RNN.Cells
import TypedFlow.Layers.RNN.Attention


================================================
FILE: TypedFlow/Layers.hs
================================================

module TypedFlow.Layers
  (module  TypedFlow.Layers.Core
  ,module  TypedFlow.Layers.RNN
  ) where

import TypedFlow.Layers.Core
import TypedFlow.Layers.RNN



================================================
FILE: TypedFlow/Learn.hs
================================================
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE PatternSynonyms #-}
{-|
Module      : TypedFlow.Learn
Description : Loss functions and optimization strategies
Copyright   : (c) Jean-Philippe Bernardy, 2017
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental
-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE UnicodeSyntax #-}

module TypedFlow.Learn
  (-- losses:
    sparseCategorical, binary, timedCategorical, categoricalDistribution,sparseCategoricalDensePredictions,
    -- types
    Options(..), defaultOptions,
    Function(..),Model,ModelOutput,
    PreparedFunction(..), PreparedModel(..),
    -- other
    simpleModel, modelFunction, probeFunction,
    addRegularizer,
    prepare,
    -- utils
    placeholderName,
  ) where

import Data.Proxy
import TypedFlow.Types
import TypedFlow.Types.Proofs (knownAppend,  (?>), )
import TypedFlow.Broadcast (doBroadcast, mapPlaceHolders, ConsSh,doBroadcastSingle)
import TypedFlow.Abstract (doExtractVars)
import TypedFlow.TF
import Prelude hiding (RealFrac(..))
import GHC.TypeLits

-- | Triple of values that are always output in a model: prediction, loss and accuracy.
-- @t@ is the type of the prediction.
-- @predictionShape@ is the shape of each prediction element.
-- @s@ is the shape of the loss and accuracy.
type ModelOutput t predictionShape s
  = Placeholders '[ '("loss",s,Float32) -- loss associated with the prediction
                  , '("accuracy",s,Float32)  -- is the prediction correct?
                  , '("y_",s++predictionShape,t) -- prediction (which can contain prediction-shaped info)
                  ]

pattern ModelOutput ::  T (s++predictionShape) t -> T s Float32 -> T s Float32 -> ModelOutput t predictionShape s
pattern ModelOutput y loss accur = PHT loss :* PHT accur :* PHT y :* Unit

-- | A standard modelling function: (input value, gold value) ↦ (prediction, accuracy, loss).
-- @input@ is the shape of the input, and @tIn@ its element type.
-- @output@ is the shape of the output (one element per individual loss and accuracy).
-- @p@ is the shape of each prediction element.
-- @g@ is the shape of each gold output (often equal to @p@), and @tOut@ the element type of gold values and predictions.
type Model input tIn g p output tOut
  = T input tIn -> T (g++output) tOut -> ModelOutput tOut p output

-- | First type argument is the number of classes. @sparseCategorical
-- logits gold@ returns (prediction, accuracy, loss).
sparseCategorical :: forall nCat. KnownNat nCat => Model '[nCat] Float32 '[] '[] '[] Int32
sparseCategorical logits y =
  let y_ = argmax0 logits
      modelCorrect = cast (equal y_ y)
      modelLoss = sparseSoftmaxCrossEntropyWithLogits y logits
  in ModelOutput y_ modelLoss modelCorrect
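{- Sketch (illustrative only): the per-example quantities that
sparseCategorical fills in, on lists. The loss is the softmax
cross-entropy with an integer label, i.e. the negative log-probability
of the gold class, computed here as the log-sum-exp of the logits minus
the gold logit (the maximum is subtracted for numerical stability).
Names are hypothetical, and this sketch's argmax breaks ties towards
the last index, which may differ from TensorFlow.

```haskell
-- Index of the largest logit (ties go to the last index in this sketch).
argmax :: [Double] -> Int
argmax xs = snd (maximum (zip xs [0 ..]))

-- Softmax cross-entropy with an integer label.
sparseSoftmaxCE :: [Double] -> Int -> Double
sparseSoftmaxCE logits gold = logSumExp - logits !! gold
  where m         = maximum logits
        logSumExp = m + log (sum [exp (l - m) | l <- logits])

-- (prediction, loss, accuracy) for one example.
sparseCategoricalL :: [Double] -> Int -> (Int, Double, Double)
sparseCategoricalL logits gold =
  (p, sparseSoftmaxCE logits gold, if p == gold then 1 else 0)
  where p = argmax logits
```
-}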

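-- | First type argument is the number of classes. Like
-- 'sparseCategorical', but the prediction is the full softmax
-- distribution over the classes. @sparseCategoricalDensePredictions
-- logits gold@ returns (prediction, accuracy, loss).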
sparseCategoricalDensePredictions :: forall nCat. KnownNat nCat
  => Tensor '[nCat] Float32
  -> Tensor '[] Int32
  -> ModelOutput  Float32 '[nCat] '[]
sparseCategoricalDensePredictions logits y =
  let y_ :: T '[nCat] Float32
      y_ = softmax0 logits
      modelCorrect = cast (equal (argmax0 logits) y)
      modelLoss = sparseSoftmaxCrossEntropyWithLogits y logits
  in ModelOutput y_ modelLoss modelCorrect


-- | First type argument is the number of classes.
-- @categoricalDistribution logits gold@ returns (prediction,
-- accuracy, loss), where @gold@ is a probability distribution over the
-- classes. Accuracy is reported as predicting the same class as the
-- gold 'winning' class.
categoricalDistribution :: forall nCat. KnownNat nCat => Model '[nCat] Float32 '[nCat] '[nCat] '[] Float32
categoricalDistribution logits y =
  ModelOutput (softmax0 logits)
              (softmaxCrossEntropyWithLogits y logits)
              (cast (equal (argmax0 @'B32 logits) (argmax0 y)))
  

-- | @timedCategorical targetWeights logits y@
--
-- @targetWeights@: a zero-one vector with one weight per time step
-- (the same length as @y@). It is intended to mask, with 0, the padding
-- positions beyond the target sequence length.
--
-- Note that the accuracy is computed by multiplying the accuracies at
-- individual time steps with the targetWeights and dividing by the sum
-- of the weights; the loss is the cross-entropy normalized the same way.
timedCategorical :: forall len nCat bits. KnownNat nCat => KnownNat len => KnownBits bits =>
  Tensor '[len] (Flt bits) -> Tensor '[len,nCat] (Flt bits) -> Tensor '[len] Int32 -> ModelOutput  (Flt bits) '[len,nCat] '[]
timedCategorical targetWeights logits y =
  let y_ :: Tensor '[len] Int32
      y_ = argmax1 logits
      modelY = softmax1 logits
      -- correct prediction for each position
      correctPrediction :: Tensor '[len] TFBool
      correctPrediction = equal y_ y
      -- total number of correct predictions
      correctPredictionWeighted :: Tensor '[] (Flt bits)
      correctPredictionWeighted = reduceSumAll (cast @(Flt bits) correctPrediction ⊙ targetWeights)
      weightSum = reduceSumAll targetWeights
      modelCorrect :: Tensor '[] Float32
      modelCorrect = cast (correctPredictionWeighted / weightSum)
      crossEntropies = zipWithT sparseSoftmaxCrossEntropyWithLogits y logits
      modelLoss = cast @Float32 (reduceSumAll (crossEntropies ⊙ targetWeights) / weightSum)
  in ModelOutput modelY modelLoss modelCorrect
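{- Sketch (illustrative only): timedCategorical's weighted accuracy and
loss on lists. Each time step contributes its 0/1 correctness and its
cross-entropy, multiplied by its target weight, and both sums are
divided by the total weight, so padding positions (weight 0) are
ignored. The helpers are included to keep the block self-contained;
all names are hypothetical.

```haskell
argmax :: [Double] -> Int
argmax xs = snd (maximum (zip xs [0 ..]))

sparseSoftmaxCE :: [Double] -> Int -> Double
sparseSoftmaxCE logits gold = logSumExp - logits !! gold
  where m         = maximum logits
        logSumExp = m + log (sum [exp (l - m) | l <- logits])

-- weights: one 0/1 weight per step; logits: one score list per step;
-- gold: the gold class per step. Returns (accuracy, loss).
timedCategoricalL :: [Double] -> [[Double]] -> [Int] -> (Double, Double)
timedCategoricalL weights logits gold = (accuracy, loss)
  where weightSum = sum weights
        correct   = [if argmax l == g then 1 else 0 | (l, g) <- zip logits gold]
        accuracy  = sum (zipWith (*) correct weights) / weightSum
        losses    = zipWith sparseSoftmaxCE logits gold
        loss      = sum (zipWith (*) losses weights) / weightSum
```
-}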

-- | Model with @n@ binary outputs.
binary :: KnownNat n => Model '[n] Float32 '[] '[] '[n] Int32
binary logits y =
  let y_ = cast @Int32 (round sigy_)
      sigy_ = sigmoid logits
  in ModelOutput (y_)
                 (sigmoidCrossEntropyWithLogits (cast @Float32 y) logits)
                 (cast (equal y_ y))
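{- Sketch (illustrative only): the loss used by binary, for one logit x
and label z. The stable formulation max x 0 - x*z + log (1 + exp (-|x|))
is the one documented for TensorFlow's
tf.nn.sigmoid_cross_entropy_with_logits; it equals the naive
-z log σ(x) - (1-z) log (1 - σ(x)) but stays finite for large |x|.
predict mirrors the rounding of the sigmoid in binary (note that
Haskell's round sends 0.5 to the even integer 0). All names are
hypothetical.

```haskell
-- Numerically stable sigmoid cross-entropy.
sigmoidCE :: Double -> Double -> Double
sigmoidCE x z = max x 0 - x * z + log (1 + exp (negate (abs x)))

-- Naive definition, for comparison; loses precision for large |x|.
sigmoidCENaive :: Double -> Double -> Double
sigmoidCENaive x z = negate (z * log s + (1 - z) * log (1 - s))
  where s = 1 / (1 + exp (negate x))

-- Predicted label: the sigmoid rounded to 0 or 1.
predict :: Double -> Int
predict x = round (1 / (1 + exp (negate x)))
```
-}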

-- | Model compiler options
data Options = Options {maxGradientNorm :: Maybe Prelude.Float -- ^ apply gradient clipping
                       }

-- | default model compiler options
defaultOptions :: Options
defaultOptions = Options {maxGradientNorm = Nothing}

type family Concatenate xs where
  Concatenate (x ': xs) = x ++ Concatenate xs
  Concatenate '[] = '[]

genPlaceholders :: All KnownPlaceholder shapesAndTypes => SList shapesAndTypes -> Placeholders shapesAndTypes
genPlaceholders Unit = Unit
genPlaceholders (ph :* names) = PHT (T (ExternalVar (Ref (placeholderName ph) typeSShape typeSTyp))) :* genPlaceholders names

placeholderName :: forall (ph :: PH)  p. KnownPlaceholder ph => p ph -> String
placeholderName proxy = refName (placeHolderRef proxy)

simpleModel :: forall p sx tx sy ty sy_ ty_.
               (KnownShape sy_, KnownShape p, KnownShape sx, KnownTyp ty_, KnownShape sy, KnownTyp tx, KnownTyp ty)
            => (Tensor sx tx -> Tensor sy ty -> ModelOutput  ty_ p sy_)
            -> Function 
simpleModel f = knownAppend @sy_ @p ?> modelFunction "runModel" f'
  where f' :: Placeholders '[ '("x",sx,tx), '("y",sy,ty)] -> ModelOutput ty_ p sy_
        f' (PHT x :* PHT y :* Unit) = f x y


-- | Add a term to the loss. This function is intended to add
-- regularizers, i.e. losses that do not depend on the predicted
-- output, but rather on the structure of a parameter.
addRegularizer :: Scalar Float32 -> Gen ()
addRegularizer r = GPState  $ \GState{..} -> ((),GState{genRegularizers=r:genRegularizers,..})


       
knownBatchModel :: forall n ps. KnownNat n => NP (Sat KnownPlaceholder) ps -> NP (Sat KnownPlaceholder) (Ap (FMap (ConsSh n)) ps)
knownBatchModel Unit = Unit
knownBatchModel (Comp Dict :* xs) = Sat :* knownBatchModel @n xs

-- | Take the mean of the loss and accuracy over the batch (and any other
-- dimensions of @s@), and add the regularization term to the loss.
consolidate :: forall s rest. KnownShape s
            => Scalar Float32
            -> Placeholders ( '("loss",s  ,Float32) ': '("accuracy",s  ,Float32) ': rest)
            -> Placeholders ( '("loss",'[],Float32) ': '("accuracy",'[],Float32) ': rest)
consolidate extraLoss (PHT loss :* PHT accur :* rest) = (PHT (reduceMeanAll loss + extraLoss) :* PHT (reduceMeanAll accur) :* rest)

class (All KnownPlaceholder ps, KnownLen ps) => KnownPHS ps
instance (All KnownPlaceholder ps, KnownLen ps) => KnownPHS ps

data PreparedFunction = PreparedFunction {pfName :: String,
                                          pfBatched :: Bool,
                                          pfInputs, pfOutputs :: SomeSuch KnownPHS Placeholders}
data PreparedModel = PreparedModel {pmBatchSize :: Integer,
                                    pmParams :: [VarInfo],
                                    pmFunctions :: [PreparedFunction]
                                   }

-- | Prepare compilation of a model by:
--
-- * extracting and exposing parameters;
-- * batching the model;
-- * exposing placeholders;
-- * consolidating loss and accuracy;
-- * adding regularizers to the loss.
prepare :: forall bs. (KnownNat bs)
        => Gen [Function]
        -> PreparedModel
prepare fGen =
  PreparedModel
    {pmBatchSize = natVal (Proxy @bs)
    ,pmParams = [VarInfo{varInitial=fmap doBroadcastSingle varInitial,..} | VarInfo{..} <- filter varTrainable vars]
    ,pmFunctions = flip map fs $ \case
        ModelFn nm st1 st2 f ->
          knownAll (knownBatchModel @bs st1) $
          knownAll (knownBatchModel @bs st2) $
          knownAll st1 $ 
          knownAll st2 $ 
          let placeHolders = genPlaceholders typeSList
              u = -777 -- magic unique identifier for the batch dimension
          in PreparedFunction nm
               True
               (SomeSuch placeHolders)
               (SomeSuch $ doBroadcast (consolidate {-@(bs ': s) @(BPH bs st2)-} regular (mapPlaceHolders @bs u True f placeHolders)))
        ProbeFn nm st1 st2 f -> 
          knownAll st1 $
          knownAll st2 $
          let placeHolders = genPlaceholders typeSList
          in PreparedFunction nm False (SomeSuch placeHolders) (SomeSuch (doBroadcast (f placeHolders)))
    }
  where (fs,finalState,vars) = doExtractVars fGen
        regular = sum (genRegularizers finalState)

data Function where
  ModelFn :: (KnownShape s, KnownLen st1, KnownLen st2)
          => String
          -> NP (Sat KnownPlaceholder) st1 -> NP (Sat KnownPlaceholder) st2 
          -> (Placeholders st1 -> Placeholders ('("loss",s,Float32) ': '("accuracy",s,Float32) ': st2)) -> Function
  ProbeFn :: (KnownLen st1, KnownLen st2, All KnownPlaceholder st1, All KnownPlaceholder st2)
          => String
          -> NP (Sat KnownPlaceholder) st1 -> NP (Sat KnownPlaceholder) st2 
          -> (Placeholders st1 -> Placeholders st2) -> Function

modelFunction :: (KnownShape s, KnownLen st1, KnownLen st2, All KnownPlaceholder st1, All KnownPlaceholder st2)
          => String
          -> (Placeholders st1 -> Placeholders ('("loss",s,Float32) ': '("accuracy",s,Float32) ': st2)) -> Function
modelFunction nm f = ModelFn nm (allKnown @KnownPlaceholder) (allKnown @KnownPlaceholder) f


probeFunction :: (KnownLen st1, KnownLen st2, All KnownPlaceholder st1, All KnownPlaceholder st2)
          => String
          -> (Placeholders st1 -> Placeholders st2) -> Function
probeFunction nm f = ProbeFn nm (allKnown @KnownPlaceholder) (allKnown @KnownPlaceholder) f




================================================
FILE: TypedFlow/Memo.hs
================================================
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}
module TypedFlow.Memo where

import qualified Data.IntMap as I
import qualified Data.Map.Strict as M
import System.Mem.StableName
import Data.IORef
import System.IO.Unsafe
import Unsafe.Coerce
import Data.Kind (Type)
type SNMap k v = I.IntMap [(StableName k,v)]

snMapLookup :: StableName k -> SNMap k v -> Maybe v
snMapLookup sn m = do
  x <- I.lookup (hashStableName sn) m
  lookup sn x

snMapInsert :: StableName k -> v -> SNMap k v -> SNMap k v
snMapInsert sn res = I.insertWith (++) (hashStableName sn) [(sn,res)]

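-- | Memoize a function, keying the table on the heap identity of the
-- argument (via 'StableName') rather than on value equality.
memo :: (a -> b) -> a -> b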
memo f = unsafePerformIO (
  do { tref <- newIORef (I.empty)
     ; return (applyStable f tref)
     })

applyStable :: (a -> b) -> IORef (SNMap a b) -> a -> b
applyStable f tbl arg = unsafePerformIO (
  do { sn <- makeStableName arg
     ; lkp <- snMapLookup sn <$> readIORef tbl
     ; case lkp of
         Just result -> return result
         Nothing ->
           do { let res = f arg
              ; modifyIORef tbl (snMapInsert sn res)
              ; return res
              }})

-- | Memoize a function, keying the table on the argument's 'Ord' instance.
memoOrd :: Ord a => (a -> b) -> a -> b
memoOrd f = unsafePerformIO (
  do { tref <- newIORef (M.empty)
     ; return (applyStableOrd f tref)
     })

applyStableOrd :: Ord a => (a -> b) -> IORef (M.Map a b) -> a -> b
applyStableOrd f tbl arg = unsafePerformIO (
  do { lkp <- M.lookup arg <$> readIORef tbl
     ; case lkp of
         Just result -> return result
         Nothing ->
           do { let res = f arg
              ; modifyIORef tbl (M.insert arg res)
              ; return res
              }})
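
-- Usage sketch (illustrative; 'fibMemo' is a hypothetical name). Since
-- 'memoOrd' compares arguments with 'Ord' rather than by pointer, it also
-- works for values such as 'Int' that are rebuilt on every call. The
-- memoised function should be bound once (here as a top-level CAF) so that
-- all calls, including the recursive ones, share a single table.
--
-- > fibMemo :: Int -> Integer
-- > fibMemo = memoOrd fib
-- >   where fib n | n < 2     = toInteger n
-- >               | otherwise = fibMemo (n - 1) + fibMemo (n - 2)
-- >
-- > -- fibMemo 50 == 12586269025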


data Some2 k1 k2 (f :: k1 -> k2 -> Type) where
  Some2 :: forall k1 k2 f a b. StableName (f a b) -> Some2 k1 k2 f

instance Eq (Some2 k1 k2 f) where
  Some2 sn1 == Some2 sn2 = eqStableName sn1 sn2

type SSNMap2 k1 k2 (f :: k1 -> k2 -> Type) v = I.IntMap [(Some2 k1 k2 f,v)]

makeSn2 :: f a b -> Some2 k1 k2 f
makeSn2 = Some2 . unsafePerformIO . makeStableName

snMapLookup2 :: Some2 k1 k2 f -> SSNMap2 k1 k2 f v -> Maybe v
snMapLookup2 (Some2 sn) m = do
  x <- I.lookup (hashStableName sn) m
  lookup (Some2 sn) x

snMapInsert2 :: Some2 k1 k2 f -> v -> SSNMap2 k1 k2 f v -> SSNMap2 k1 k2 f v
snMapInsert2 (Some2 sn) res = I.insertWith (++) (hashStableName sn) [(Some2 sn,res)]

data KV k1 k2 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type)  where
  KV :: forall k1 k2 f v a b. StableName (f a b) -> v a b -> KV k1 k2 f v

type SNMap22 k1 k2 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) = I.IntMap [KV k1 k2 f v]

snMap22Lookup :: StableName (f a b) -> SNMap22 k1 k2 f v -> Maybe (v a b)
snMap22Lookup sn  m = do
  x <- I.lookup (hashStableName sn) m
  lkKV sn x

lkKV :: StableName (f a b) -> [KV k1 k2 f v] -> Maybe (v a b)
lkKV _ [] = Nothing
lkKV sn (KV sn' v:kvs) | eqStableName sn sn' = Just (unsafeCoerce v) -- sn == sn' -> a == a' and b == b' 
                       | otherwise = lkKV sn kvs

snMap22Insert :: KV k1 k2 f v -> SNMap22 k1 k2 f v -> SNMap22 k1 k2 f v
snMap22Insert (KV sn res) = I.insertWith (++) (hashStableName sn) [KV sn res]


-- | The type of a memo table for functions of a.
type Memo a = forall r. (a -> r) -> (a -> r)

-- | Memoize a two argument function (just apply the table directly for
-- single argument functions).
memo2 :: Memo a -> Memo b -> (a -> b -> r) -> (a -> b -> r)
memo2 a b = a . (b .)

-- | Memoize a three argument function.
memo3 :: Memo a -> Memo b -> Memo c -> (a -> b -> c -> r) -> (a -> b -> c -> r)
memo3 a b c = a . (memo2 b c .)
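
-- Usage sketch (illustrative; 'choose' is a hypothetical name). 'memoOrd' can
-- be used as a 'Memo' table for any 'Ord' type. @memo2 a b f@ builds an outer
-- table keyed on the first argument whose entries are inner tables (one per
-- first argument) keyed on the second.
--
-- > choose :: Int -> Int -> Integer
-- > choose = memo2 memoOrd memoOrd go
-- >   where go _ 0 = 1
-- >         go n k | k == n    = 1
-- >                | otherwise = choose (n - 1) (k - 1) + choose (n - 1) k
-- >
-- > -- choose 30 15 == 155117520   (valid for 0 <= k <= n)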


================================================
FILE: TypedFlow/Memo2.hs
================================================
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}

module TypedFlow.Memo2 where

import Data.Kind (Type)
import qualified Data.Map.Strict as M
import System.Mem.StableName
-- import Data.IORef
-- import System.IO.Unsafe
import Unsafe.Coerce
import qualified Data.IntMap as I
import Data.Type.Equality
import Control.Monad.IO.Class
import Data.IORef
import TypedFlow.Types.Proofs (SingEq(..))
import Data.List (intercalate)

data Map0 k (m :: Type -> Type) f v = Map0 {
  m0Key :: f -> IO k,
  m0Empty :: m v,
  m0lk  :: k -> m v -> Maybe v,
  m0upd :: k -> (Maybe v -> v) -> m v -> m v,
  m0fmap :: forall u w.  (u -> w) -> m u -> m w,
  m0showKey :: k -> String,
  m0showTbl :: (v -> String) -> (m v -> String)
  }


data Map1 (k :: k1 -> Type) (m :: (k1 -> Type) -> Type)  (f :: k1 -> Type) (v :: k1 -> Type) = Map1 {
  m1Key :: forall x. f x -> IO (k x),
  m1Empty :: m v,
  m1lk  :: forall x. k x -> m v -> Maybe (v x),
  m1upd :: forall x. k x -> (Maybe (v x) -> (v x)) -> m v -> m v,
  m1showKey :: forall x . k x -> String,
  m1showTbl :: (forall x . v x -> String) -> (m v -> String)
  }

data Map2 (k :: k1 -> k2 -> Type) (m :: (k1 -> k2 -> Type) -> Type)  (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) = Map2 {
  m2Key :: forall x y. f x y -> IO (k x y),
  m2Empty :: m v,
  m2lk  :: forall x y. k x y -> m v -> Maybe (v x y),
  m2upd :: forall x y. k x y -> (Maybe (v x y) -> (v x y)) -> m v -> m v,
  -- m2fmap :: forall u w.  (forall x y. u x y -> w x y) -> m u -> m w,
  m2showKey :: forall x y. k x y -> String,
  m2showTbl :: (forall x y. v x y -> String) -> (m v -> String)
  }

data Map3 (k :: k1 -> k2 -> k3 -> Type) (m :: (k1 -> k2 -> k3 -> Type) -> Type)  (f :: k1 -> k2 -> k3 -> Type) (v :: k1 -> k2 -> k3 -> Type) = Map3 {
  m3Key :: forall x y z. f x y z -> IO (k x y z),
  m3Empty :: m v,
  m3lk  :: forall x y z. k x y z -> m v -> Maybe (v x y z),
  m3upd :: forall x y z. k x y z -> (Maybe (v x y z) -> (v x y z)) -> m v -> m v,
  m3showKey :: forall x y z. k x y z -> String,
  m3showTbl :: (forall x y z. v x y z -> String) -> (m v -> String)
  }

newtype Id x = Id x

ordMap :: forall k b. (Ord k, Show k) => Map0 k (M.Map k) k b
ordMap = Map0 {..} where
  m0Key = return
  m0Empty = mempty
  m0lk k = M.lookup k
  m0upd k f m = M.alter (Just . f) k m
  m0fmap = fmap
  m0showKey = show
  m0showTbl sh m = intercalate ";" [(show k) <> "↦" <> (sh v) | (k,v) <- M.assocs m]
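
-- Usage sketch (illustrative; 'om' and 'tbl' are hypothetical names). A
-- 'Map0' packages the operations of a finite map; 'ordMap' wraps
-- "Data.Map.Strict". Here two entries are inserted with 'm0upd' and one is
-- read back with 'm0lk':
--
-- > om :: Map0 Int (M.Map Int) Int String
-- > om = ordMap
-- >
-- > tbl :: M.Map Int String
-- > tbl = m0upd om 2 (const "two") (m0upd om 1 (const "one") (m0Empty om))
-- >
-- > -- m0lk om 2 tbl      == Just "two"
-- > -- m0showTbl om id tbl == "1↦one;2↦two"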

data Single1 f g where
  None1 :: Single1 f g
  Single1 :: f a -> g a -> Single1 f g 

verifMap1 :: forall k v. SingEq k => Map1 k (Single1 k) k v
verifMap1 = Map1 {..} where
  m1Key = return
  m1Empty = None1
  m1lk :: k a -> Single1 k b -> Maybe (b a)
  m1lk k = \case
    None1 -> Nothing
    Single1 k' v -> case testEq k k' of
      Just Refl -> Just v
      Nothing -> error "verifMap1: mismatching keys! (1)"
  m1upd :: forall x. k x -> (Maybe (v x) -> (v x)) -> Single1 k v -> Single1 k v
  m1upd k f None1 = Single1 k (f Nothing)
  m1upd k f (Single1 k' v) = case testEq k k' of
      Just Refl -> Single1 k (f (Just v))
      Nothing -> error "verifMap1: mismatching keys! (2)"
  m1showKey _ = "#"
  m1showTbl :: (forall x . v x -> String) -> (Single1 k v -> String)
  m1showTbl _ None1 = "·"
  m1showTbl h (Single1 _ v) = "!" <> (h v)


testStable :: StableName a -> StableName b -> Maybe (a :~: b)
testStable sn sn' | eqStableName sn sn' = Just (unsafeCoerce Refl)
                  | otherwise = Nothing

snMap2 :: forall f v. Map2 (SN2 f) (SNMap22 f) f v
snMap2 = Map2 {..} where
  m2showTbl :: (forall x y. v x y -> String) -> (SNMap22 f v -> String)
  m2showTbl h (SNMap22 m) = intercalate "," [ m2showKey k <> "↦" <> h v | e <- I.elems m, KV k v <- e   ]
  m2showKey (SN2 sn) = show (hashStableName sn)
  m2Key obj = SN2 <$> makeStableName obj
  m2Empty = mempty
  m2lk = snMap22Lookup
  m2upd :: SN2 f x y -> (Maybe (v x y) -> (v x y)) -> SNMap22 f v -> SNMap22 f v
  m2upd (SN2 sn) f (SNMap22 m) = SNMap22 $
                                 I.alter (\case Nothing -> Just [KV (SN2 sn) (f Nothing)]
                                                Just p -> Just (updKV (SN2 sn) f p))
                                 (hashStableName sn)
                                 m

  updKV :: SN2 f' x y -> (Maybe (v' x y) -> (v' x y)) -> [KV k1 k2 (SN2 f') v'] -> [KV k1 k2 (SN2 f') v']
  updKV (SN2 sn) f [] = [KV (SN2 sn) (f Nothing)]
  updKV (SN2 sn) f (v@(KV (SN2 sn') x):xs) = case testStable sn sn' of
    Just Refl -> KV (SN2 sn') (f (Just x)):xs
    Nothing -> v : updKV (SN2 sn) f xs
                                 
  -- m2fmap :: forall u w.  (forall x y. u x y -> w x y) -> SNMap22 f u -> SNMap22 f w
  -- m2fmap h (SNMap22 t) = SNMap22 (fmap (fmap (\(KV k v) -> KV k (h v))) t)

  snMap22Lookup :: forall a b f' v'. SN2 f' a b -> SNMap22 f' v' -> Maybe (v' a b)
  snMap22Lookup (SN2 sn) (SNMap22 m) = do
    x <- I.lookup (hashStableName sn) m
    lkKV sn x

  lkKV :: forall k1 k2 f' v' a b . StableName (f' a b) -> [KV k1 k2 (SN2 f') v'] -> Maybe (v' a b)
  lkKV _ [] = Nothing
  lkKV sn (KV (SN2 sn') v:kvs) = case testStable sn sn' of
                             Just Refl ->  Just (unsafeCoerce v) -- sn == sn' -> a == a' and b == b' 
                             Nothing ->  lkKV sn kvs


data KV k1 k2 (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type)  where
  KV :: forall k1 k2 f v a b. f a b -> v a b -> KV k1 k2 f v

newtype SNMap22  (f :: k1 -> k2 -> Type) (v :: k1 -> k2 -> Type) = SNMap22 (I.IntMap [KV k1 k2 (SN2 f) v]) deriving (Monoid, Semigroup)

newtype SN2 (f :: k1 -> k2 -> Type) a b = SN2 (StableName (f a b)) 

data (:.:) (m1 :: k2 -> Type) (m2 :: k1 -> k2) (h :: k1) = Comp (m1 (m2 h))


data Sig02 f g x y where
  Ex02 :: f -> g x y -> Sig02 f g x y

data Sig03 f g x y z where
  Ex03 :: f -> g x y z -> Sig03 f g x y z

data Sig12 f g x y z where
  Ex12 :: f x -> g y z -> Sig12 f g x y z

data Sig22 f g x y where
  Ex22 :: f x y -> g x y -> Sig22 f g x y

data P33 f g x y z where
  T33 :: f x y z -> g x y z -> P33 f g x y z



containing00 :: (forall v. Map0 k1 m1 f v) -> Map0 k2 m2 g h -> Map0 (k1,k2) (m1 :.: m2)  (f,g) h
containing00 f g  = Map0
   {
   m0Key = (\(a,b) -> (,) <$> m0Key f a <*> m0Key g b),
   m0Empty = Comp (m0Empty f),
   m0lk = \(k1,k2) (Comp t) -> do t' <- m0lk f k1 t; m0lk g k2 t',
   m0upd = \(k1,k2) h (Comp t) -> Comp $ m0upd f k1 (m0upd g k2 h . \case Just tb -> tb; Nothing -> (m0Empty g)) t,
   m0fmap = \h (Comp t) -> Comp $ m0fmap f (m0fmap g h) t,
   m0showKey = \(k1,k2) -> m0showKey f k1 <> "," <> m0showKey g k2,
   m0showTbl = \h (Comp t) -> m0showTbl f (m0showTbl g h) t
   }                      
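
-- Usage sketch (illustrative; 'pairMap' is a hypothetical name).
-- 'containing00' nests two tables: the outer one is keyed on the first
-- component of the pair, and each of its entries is an inner table keyed on
-- the second component.
--
-- > pairMap :: Map0 (Int,Char) (M.Map Int :.: M.Map Char) (Int,Char) String
-- > pairMap = containing00 ordMap ordMap
-- >
-- > -- m0lk pairMap (1,'a') (m0upd pairMap (1,'a') (const "x") (m0Empty pairMap))
-- > --   == Just "x"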

containing02 :: (forall v. Map0 k1 m1 f v) -> Map2 k2 m2 g h -> Map2 (Sig02 k1 k2) (m1 :.: m2) (Sig02 f g)  h
containing02 f g = Map2
   {
   m2Key = (\(Ex02 a b) -> Ex02 <$> m0Key f a <*> m2Key g b),
   m2Empty = Comp (m0Empty f),
   m2lk = \(Ex02 k1 k2) (Comp t) -> do t' <- m0lk f k1 t; m2lk g k2 t',
   m2upd = \(Ex02 k1 k2) h (Comp t) -> Comp $ m0upd f k1 (m2upd g k2 h . \case Just tb -> tb; Nothing -> (m2Empty g)) t,
   -- m2fmap = \h (Comp t) -> Comp $ m0fmap f (m2fmap g h) t,
   m2showKey = \(Ex02 k1 k2) -> m0showKey f k1 <> "," <> m2showKey g k2,
   m2showTbl = \h (Comp t) -> m0showTbl f (m2showTbl g h) t
   }                      

containing03 :: (forall v. Map0 k1 m1 f v) -> Map3 k2 m2 g h -> Map3 (Sig03 k1 k2) (m1 :.: m2) (Sig03 f g)  h
containing03 f g = Map3
   {
   m3Key = (\(Ex03 a b) -> Ex03 <$> m0Key f a <*> m3Key g b),
   m3Empty = Comp (m0Empty f),
   m3lk = \(Ex03 k1 k3) (Comp t) -> do t' <- m0lk f k1 t; m3lk g k3 t',
   m3upd = \(Ex03 k1 k3) h (Comp t) -> Comp $ m0upd f k1 (m3upd g k3 h . \case Just tb -> tb; Nothing -> (m3Empty g)) t,
   m3showKey = \(Ex03 k1 k2) -> m0showKey f k1 <> "," <> m3showKey g k2,
   m3showTbl = \h (Comp t) -> m0showTbl f (m3showTbl g h) t
   }

newtype Lam' (m2 :: (k2 -> k3 -> Type) -> Type) (h :: k1 -> k2 -> k3 -> Type) (a :: k1) = Lam' {fromLam' :: (m2 (h a))}
data M12 (m1 :: (k1 -> Type) -> Type) (m2 :: (k2 -> k3 -> Type) -> Type) (h :: k1 -> k2 -> k3 -> Type) = M12 (m1 (Lam' m2 h))

containing12 :: (forall v. Map1 k1 m1 f v) -> (forall k4. Map2 k2 m2 g (h k4)) -> Map3 (Sig12 k1 k2) (M12 m1 m2) (Sig12 f g) h
containing12 f g = Map3
   {
   m3Key = (\(Ex12 a b) -> Ex12 <$> m1Key f a <*> m2Key g b),
   m3Empty = M12 (m1Empty f),
   m3lk = \(Ex12 k1 k2) (M12 t) -> do Lam' t' <- m1lk f k1 t; m2lk g k2 t',
   m3upd = \(Ex12 k1 k2) h (M12 t) -> M12 $ m1upd f k1 (Lam' . m2upd g k2 h . (\case Just tb -> tb; Nothing -> m2Empty g) . fmap fromLam') t,
   m3showKey = \(Ex12 k1 k2) -> m1showKey f k1 <> ">" <> m2showKey g k2,
   m3showTbl = \h (M12 t) -> m1showTbl f (m2showTbl g h . fromLam') t
   }



data F2m m g h = F2m (forall x y. g x y -> m (h x y))
data F2m' m g f h = F2m' (forall x y. g x y -> f x y -> m (h x y))

data F3m m g h = F3m (forall x y z. g x y z -> m (h x y z))
data F3m' m g f h = F3m' (forall x y z. g x y z -> f x y z -> m (h x y z))

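-- | Memoise an open-recursive monadic function on a type @g@ indexed by two
-- type arguments. The function receives the memoised version of itself, to be
-- used for recursive calls; results are stored in a table built from the
-- given 'Map2' and kept in an 'IORef'.
memo2 :: forall g h k m n. MonadIO n => Map2 k m g h -> ((forall x y. g x y -> n (h x y)) -> forall x y. g x y -> n (h x y)) -> n (F2m n g h)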
memo2 Map2{..} f = do
    tblRef <- liftIO $ newIORef m2Empty
    let finished :: forall x y. g x y -> n (h x y)
        finished arg = do
          tbl <- liftIO $ readIORef tblRef
          key <- liftIO $ m2Key arg
          case m2lk key tbl of
            Just result -> do
              -- liftIO $ putStrLn "memo2: hit"
              return result
            Nothing -> do
              -- liftIO $ putStrLn "memo2: miss"
              res <- f finished arg
              liftIO $ modifyIORef tblRef (m2upd key $ \_ -> res)
              return res
    return (F2m finished)
  
memo2' :: forall g f h k m n. MonadIO n => Map2 k m g h -> ((forall x y. g x y -> f x y -> n (h x y)) -> forall x y. g x y -> f x y -> n (h x y)) -> n (F2m' n g f h)
memo2' Map2{..} f = do
    tblRef <- liftIO $ newIORef m2Empty
    let finished :: forall x y. g x y -> f x y -> n (h x y)
        finished arg extra = do
          tbl <- liftIO $ readIORef tblRef
          key <- liftIO $ m2Key arg
          case m2lk key tbl of
            Just result -> do
              -- liftIO $ putStrLn "memo2': hit"
              return result
            Nothing -> do
              -- liftIO $ putStrLn ("memo2': miss " <> m2showKey key) --  <> " from " <> m2showTbl (const ".") tbl
              res <- f finished arg extra
              liftIO $ modifyIORef tblRef (m2upd key $ \_ -> res)
              return res
    return (F2m' finished)

memo3' :: forall g f h k m n. MonadIO n => Map3 k m g h -> ((forall x y z. g x y z -> f x y z -> n (h x y z)) -> forall x y z. g x y z -> f x y z -> n (h x y z)) -> n (F3m' n g f h)
memo3' Map3{..} f = do
    tblRef <- liftIO $ newIORef m3Empty
    let finished :: forall x y z. g x y z -> f x y z -> n (h x y z)
        finished arg extra = do
          tbl <- liftIO $ readIORef tblRef
          key <- liftIO $ m3Key arg
          case m3lk key tbl of
            Just result -> do
              -- liftIO $ putStrLn "memo3: hit"
              return result
            Nothing -> do
              -- liftIO $ putStrLn ("memo3: miss " <> m3showKey key) --  <> " from " <> m3showTbl (const ".") tbl
              res <- f finished arg extra
              liftIO $ modifyIORef tblRef (m3upd key $ \_ -> res)
              return res
    return (F3m' finished)


================================================
FILE: TypedFlow/Models/Topic.hs
================================================
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UnicodeSyntax #-}
{-|
Module      : TypedFlow.Models.Topic
Description : Topic models
Copyright   : (c) Jean-Philippe Bernardy, 2017
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental
-}


module TypedFlow.Models.Topic where
import Prelude hiding (RealFrac(..))
import TypedFlow.TF
import TypedFlow.Layers
import TypedFlow.Types
import TypedFlow.Types.Proofs ((?>), knownSum')
import TypedFlow.Learn
import GHC.TypeLits
import Data.Monoid ((<>))
import Data.Proxy

-- | A convolutional document summary function. Described in
-- 'Topically Driven Neural Language Model' by Lau, Baldwin and Cohn.
tdlmDocsummary :: forall
  (vocSize :: Nat) -- number of words
  (e :: Nat) -- size of the embedding
  (a :: Nat) -- number of features of the document vector summary 
  (n :: Nat) -- length of the document
  (filterSize :: Nat) -- size of the convolution filter
  (t :: NBits) -- size of floats
  .  KnownNat vocSize => KnownNat filterSize => KnownNat e => KnownNat a => KnownNat n => KnownBits t
  => (EmbeddingP vocSize e (Flt t))
  -> (ConvP (Flt t) a e '[filterSize])
  -> DropProb
  -> Gen (T '[n] Int32 -> T '[a] (Flt t))
tdlmDocsummary embs filters dropProb = do
  drpEmb <- mkDropout dropProb
  return $ \document ->
    let embeddedDoc :: Tensor [n,e] (Flt t)
        embeddedDoc = mapT (drpEmb . embedding @e @vocSize embs) document
    in reduceMax axis0 (conv' @'[n] filters embeddedDoc)

tdlmDocsummary' :: forall
  (vocSize :: Nat) -- number of words
  (e :: Nat) -- size of the embedding
  (n :: Nat) -- length of the document
  -- (a :: Nat) -- number of features of the document vector summary 
  -- (filterSize :: Nat) -- size of the convolution filter
  spec
  (t :: NBits) -- size of floats
  proxy
  .  KnownNat vocSize => KnownNat (Ap Frst' spec) => KnownNat e => KnownNat (Ap Scnd' spec) => KnownNat n => KnownBits t
  => proxy spec
  -> (EmbeddingP vocSize e (Flt t))
  -> (ConvP (Flt t) (Ap Scnd' spec) e '[(Ap Frst' spec)])
  -> DropProb
  -> Gen (T '[n] Int32 -> T '[Ap Scnd' spec] (Flt t))
tdlmDocsummary' _proxy  = tdlmDocsummary

scnds :: SList xs -> SList (Ap (FMap Scnd') xs)
scnds Unit = Unit
scnds (_ :* xs) = Proxy :* scnds xs
-- hmap _ Unit = Unit
-- hmap f (x :* xs) = f x :* hmap f xs

mkTdlmDocsummary :: forall
  (vocSize :: Nat) -- number of words
  (e :: Nat) -- size of the embedding
  (spec :: [(Nat,Nat)]) -- (size of the convolution filter,number of features) 
  (n :: Nat) -- length of the document
  (t :: NBits) -- size of floats
  .  KnownNat vocSize => KnownNat e => KnownNat n => KnownBits t
  => All KnownNat (Ap (FMap Scnd') spec)
  => All KnownNat (Ap (FMap Frst') spec)
  => SList spec
  -> DropProb
  -> Gen (T '[n] Int32 -> T '[Sum (Ap (FMap Scnd') spec)] (Flt t))
mkTdlmDocsummary xs0 dropProb = case xs0 of
  Unit -> return (\_ -> zeros)
  (proxy :* xs) -> knownSum' (scnds xs) ?>
                   do embs <- parameterDefault ("embs_topic_" ++ show (sListLength xs))
                      filters <- parameterDefault ("filters_topic_" ++ show (sListLength xs))
                      f <- tdlmDocsummary' @vocSize @e proxy embs filters dropProb
                      fs <- mkTdlmDocsummary @vocSize @e xs dropProb
                      return $ \input -> concat0 (f input) (fs input)

-- | Parameter for topics. This is effectively a map from document
-- features (a) to topic representations (vectors of size b) via a
-- distribution over k topics.
data TopicP t a k b = TopicP {topicDistributions :: (T '[a,k] (Flt t))  -- ^ a linear map from document features (a) to topic distributions (k)
                             ,topicRepresentations :: (T '[k,b] (Flt t)) -- ^ a linear map from topic distributions (k) to topic representations (b)
                             }

instance (KnownNat a, KnownNat k, KnownNat b, KnownBits t) => KnownTensors (TopicP t a k b) where
  travTensor f s (TopicP x y) = TopicP <$> travTensor f (s<>"_A") x <*> travTensor f (s<>"_B") y
instance (KnownNat a, KnownNat k, KnownNat b, KnownBits t) => ParamWithDefault (TopicP t a k b) where
  defaultInitializer = TopicP <$> glorotUniform <*> glorotUniform

-- | A topic modeler. Described in 'Topically Driven Neural Language
-- Model' by Lau, Baldwin and Cohn. Returns a function converting raw
-- representations (e.g. document summaries) to topic representations,
-- paired with the corresponding distribution over topics. The topic
-- representation can be used as input to a dense layer to predict a
-- word, or as the initial state of an LSTM to predict sentences. It
-- also adds a regularizer that penalises overlapping topics.
mkTdlmTopic :: forall
  (kk :: Nat) -- number of topics
  (a :: Nat) -- document vector summary size
  (b :: Nat) -- topic representation size
  (t :: NBits) -- size of floats
  . KnownNat kk => KnownNat a => KnownNat b => KnownBits t
  => Float -> TopicP t a kk b -> Gen (T '[a] (Flt t) -> (Tensor '[b] (Flt t), Tensor '[kk] (Flt t)))
mkTdlmTopic separationConstant (TopicP topicInput topicOutput) = do
  drpS   <- mkDropout (DropProb 0.1)
  let topicNormalized :: T '[kk,b] (Flt t)
      topicNormalized = mapT normalize topicOutput
      -- matrix of correlations between the (normalized) topics
      topicCorrelation :: T '[kk,kk] (Flt t)
      topicCorrelation = matmul topicNormalized (transpose01 topicNormalized)
      -- maximum squared correlation between two distinct topics (the identity
      -- is subtracted to ignore each topic's correlation with itself)
      topicOverlap = reduceMaxAll (square (topicCorrelation ⊝ eye))
  addRegularizer (constant separationConstant ⊙ cast topicOverlap) -- regularizer which encourages topics to be distinct

  return (\d -> let p :: T '[kk] (Flt t)
                    p = softmax0 (topicInput ∙ d) -- attention distribution (among the topics)
                in (drpS (topicOutput ∙ p), p))
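
-- Worked example of the separation regularizer, assuming 'normalize' scales
-- each topic vector to unit L2 norm (so that the diagonal of topicCorrelation
-- is 1). With kk = 2 topics of size b = 2:
--
-- >  topics [1,0] and [0,1]  ⇒  topicCorrelation = [[1,0],[0,1]]  ⇒  topicOverlap = 0
-- >  topics [1,0] and [1,0]  ⇒  topicCorrelation = [[1,1],[1,1]]  ⇒  topicOverlap = 1
--
-- so the penalty separationConstant · topicOverlap grows as two topics become
-- more similar.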



-- | Gating unit which can be used to mix an RNN hidden state with an
-- external information source (e.g. a topic representation). Described
-- in 'Topically Driven Neural Language Model' by Lau, Baldwin and Cohn;
-- formula (3).
tdlmGatingUnit :: KnownNat n => KnownFloat t => KnownNat m => (GRUP t m n) -> T '[n] t -> T '[m] t -> (T '[m] t)
tdlmGatingUnit (GRUP wz wr w) s h = 
  let x = concat0 h s
      z = sigmoid (wz ∙ x)
      r = sigmoid (wr ∙ x)
      hTilda = tanh (w ∙ (concat0 (r ⊙ h) s))
  in ((ones ⊝ z) ⊙ h + z ⊙ hTilda)
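
-- In formulas, writing [u; v] for 'concat0' u v, h for the RNN state and s
-- for the external input, the function above computes
--
-- >  z  = σ(W_z [h; s])
-- >  r  = σ(W_r [h; s])
-- >  h̃  = tanh(W [r ⊙ h; s])
-- >  h' = (1 − z) ⊙ h + z ⊙ h̃
--
-- i.e. a GRU-style update of h in which s plays the role of the GRU input.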


================================================
FILE: TypedFlow/Models/Transformer.hs
================================================
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE NoStarIsType #-}
{-|
Module      : TypedFlow.Models.Transformer
Description : Transformer models
Copyright   : (c) Jean-Philippe Bernardy, 2020
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental
-}


module TypedFlow.Models.Transformer where
import Prelude hiding (RealFrac(..))
import TypedFlow.TF
import TypedFlow.Abstract
import TypedFlow.Layers
import TypedFlow.Types
import TypedFlow.Types.Proofs ((?>), knownSum')
import GHC.TypeLits

-- Convention for type variables:
-- h = number of heads
-- e = embedding size
-- n = sequence length

average :: forall e. KnownNat e => T '[e] Float32 -> Scalar Float32
average = reduceMeanAll

-- | Normalise a vector to zero mean and unit standard deviation, adding a
-- small epsilon to the denominator to avoid division by zero.
normalizer :: forall e. KnownNat e => T '[e] Float32 -> T '[e] Float32
normalizer x = mapT (⊘ (sigma + epsilon)) xmu -- so the standard deviation of the result is almost 1
  where mu = average x
        xmu = mapT (⊝ mu) x  -- so the average of xmu is 0
        sigma = sqrt (average (square xmu)) -- the standard deviation of x (root mean square of xmu)
        epsilon = 0.001 -- ?
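
-- Worked example: for x = [1, 3], mu = 2, xmu = [-1, 1] and sigma = 1, so
-- normalizer x = [-1/1.001, 1/1.001] ≈ [-0.999, 0.999].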

-- Informally:
-- mapT f x = vector y such that y_i = f (x_i) -- (the first axis)

dimAsFloat :: forall e. KnownNat e => Float
dimAsFloat = fromIntegral (knownNatVal (natSat @e))

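-- | Scaled dot-product attention for a single vector @k@: the scores @q ∙ k@
-- (one per position, divided by √e) are normalised with a softmax and used to
-- average the n vectors of @v@.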
dotAttention1 :: forall e n. KnownNat e => KnownNat n
  => T '[e,n] Float32 -> T '[n,e] Float32 -> T '[e] Float32 -> T '[e] Float32
dotAttention1 q v k = v ∙ softmax0 (mapT (⊘ normFactor) (q ∙ k))
  where normFactor = constant (sqrt (dimAsFloat @e))

-- | Scaled dot-product attention for every position: 'dotAttention1' is
-- applied to each of the n vectors of @k@.
dotAttention :: forall n e. KnownNat n => KnownNat e
  => T '[n,e] Float32 -> T '[n,e] Float32 -> T '[n,e] Float32 -> T '[n,e] Float32
dotAttention v k q = mapT (dotAttention1 (transpose01 q) v) k
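
-- In index notation, writing q_j, k_r and v_j for the vectors at positions j
-- and r of the [n,e] arguments, 'dotAttention' v k q computes, for every r,
--
-- >  o_r = Σ_j softmax_j( (q_j · k_r) / √e ) v_j
--
-- (reading the types of 'dotAttention1', '∙' contracts the first axis of its
-- matrix argument). Note that the rows of @k@ thus play the role usually
-- called "queries", and the rows of @q@ the role of "keys".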

-- | h linear projections of a vector (one per head), computed by a single
-- dense layer whose output is reshaped to shape [h,e].
multiheadLinearEncoder :: forall h e. KnownNat e => KnownNat h =>
  String -> Gen (T '[e] Float32 -> T '[h,e] Float32)
multiheadLinearEncoder name = do
  wv <- parameterDefault ("w_" ++ name)
  return $ \x -> reshape (wv # x)

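-- | Multi-head self-attention layer. Every position is projected into h
-- heads (separately for values, queries and keys), attention is computed
-- head by head, the heads are concatenated and projected back to size e by
-- a dense layer, and the result is added to the input (residual connection)
-- and normalised.
multiheadSelfAttentionModule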
  :: forall h n e. KnownNat n => KnownNat h => KnownNat e
  => String -> Gen (T '[n,e] Float32 -> T '[n,e] Float32)
multiheadSelfAttentionModule nm = do
  ev <- multiheadLinearEncoder @h ("v" ++ nm)
  eq <- multiheadLinearEncoder @h ("q" ++ nm)
  ek <- multiheadLinearEncoder @h ("k" ++ nm)
  w1 <- parameterDefault ("w1" ++ nm)
  -- w2 <- parameterDefault ("w2" ++ nm)
  return $ \x ->
    let v = transpose01 (mapT ev x)
        q = transpose01 (mapT eq x)
        k = transpose01 (mapT ek x)
        r :: T '[n,h,e] Float32
        r = transpose01 (zipWith3T dotAttention q k v)
        r' = mapT (dense @e w1 . reshape @'[h * e]) r
    in mapT ({-dense w2 . -}normalizer) (r' + x)
       -- x + mapT normalizer r'

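-- | Variant of 'multiheadSelfAttentionModule' taking two sequences: the
-- first argument comes from the decoder and the second from the encoder.
multiheadSelfAttentionModuleDecoder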
  :: forall h n e. KnownNat n => KnownNat h => KnownNat e
  => String -> Gen (T '[n,e] Float32 -> T '[n,e] Float32  -> T '[n,e] Float32)
multiheadSelfAttentionModuleDecoder nm = do
  ev <- multiheadLinearEncoder @h ("v" ++ nm)
  eq <- multiheadLinearEncoder @h ("q" ++ nm)
  ek <- multiheadLinearEncoder @h ("k" ++ nm)
  w1 <- parameterDefault ("w1" ++ nm)
  -- w2 <- parameterDefault ("w2" ++ nm)
  return $ \x    -- comes from decoder
            y    -- comes from encoder
           ->
    let k = transpose01 (mapT ek y)
        v = transpose01 (mapT ev x)
        q = transpose01 (mapT eq y)
        r :: T '[n,h,e] Float32
        r = transpose01 (zipWith3T dotAttention q k v)
        r' = mapT (dense @e w1 . reshape @'[h * e]) r
    in mapT ({-dense w2 . -}normalizer) (r' + x)
       -- x + mapT normalizer r'


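-- | Position-wise feed-forward module: two dense layers with a ReLU in
-- between, followed by a residual connection and normalisation.
feedForwardModule :: forall e. KnownNat e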
  => String -> Gen (T '[e] Float32 -> T '[e] Float32)
feedForwardModule nm = do
  w1 :: DenseP Float32 e e <- parameterDefault (nm ++ "w1")
  w2 <- parameterDefault (nm ++ "w2")
  return $ \x -> normalizer (x + (w2 # relu (w1 # x)))

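-- | One encoder layer: dropout, addition of the positional tensor,
-- multi-head self-attention, then the feed-forward module at every position.
encoderModule :: forall h n e. KnownNat n => KnownNat h => KnownNat e => DropProb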
  -> String -> T '[n,e] Float32 -> Gen (T '[n,e] Float32 -> T '[n,e] Float32)
encoderModule dropProb nm positionalTensor = do
  drp <- mkDropout dropProb
  selfAtt <- multiheadSelfAttentionModule @h (nm ++ "mh")
  ff <- feedForwardModule (nm ++ "ff")
  return (mapT ff . selfAtt . (+ positionalTensor) . drp)

positionalModuleSinCos :: forall n e. KnownNat e => KnownNat n => T '[n,e] Float32
positionalModuleSinCos = sin (transpose01 (broadcastT pos) * (broadcastT omega) + broadcastT phase)
  where pos = (cast (range @n @'B32)) :: T '[n] Float32
        phase = cast ((range @e @'B32) `floorMod` constant 2) * (constant pi/2) :: T '[e] Float32
        omega = exp (constant (-2.0 * log 10000 / dimAsFloat @e) * cast (range @e @'B32)) -- omega_j = 10000^(-2j/e)
        -- Note: I'm not dividing the frequency index by 2 because integer
        -- division isn't implemented. Should not have any consequence.
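
-- For reference, the sinusoidal encoding of Vaswani et al. (2017, "Attention
-- Is All You Need") is
--
-- >  PE(p, 2i)   = sin(p / 10000^(2i/e))
-- >  PE(p, 2i+1) = cos(p / 10000^(2i/e))
--
-- The definition above obtains both cases from a single sin, using the phase
-- (j mod 2)·π/2 and the identity sin(θ + π/2) = cos θ.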

positionalModuleLearned :: KnownNat e => KnownNat n => Gen (T '[n,e] Float32)
positionalModuleLearned = do
  e <- parameterDefault "positional"
  return $ let EmbeddingP x = e in x

encoderStack :: forall h n e. KnownNat h => KnownNat n => KnownNat e
  => DropProb -> Int -> Gen (T '[n,e] Float32 -> T '[n,e] Float32)
encoderStack dropProb n = do
  p <- positionalModuleLearned
  encoders <- mapM (\i -> encoderModule @h dropProb ("enc" ++ show i) p) [1..n]
  return (foldr (.) id encoders) -- n-ary function composition
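
-- Usage sketch (illustrative; 'mkEncoder' and the sizes are arbitrary
-- choices): a stack of 6 encoder layers with 8 attention heads, over
-- sequences of 50 positions with embeddings of size 512. All layers share
-- the same learned positional tensor, which each layer adds to its input.
--
-- > mkEncoder :: Gen (T '[50,512] Float32 -> T '[50,512] Float32)
-- > mkEncoder = encoderStack @8 (DropProb 0.1) 6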


================================================
FILE: TypedFlow/Python.hs
================================================
{-# LANGUAGE ViewPatterns #-}
{-|
Module      : TypedFlow.Python
Description : Python-generation Functions 
Copyright   : (c) Jean-Philippe Bernardy, 2017
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental

-}

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}

module TypedFlow.Python (compile, compileGen, generateFile) where

import Data.Char (toLower)
import Data.Proxy
import Data.List (genericReplicate, )
import GHC.TypeLits
import Control.Monad.State
import TypedFlow.Types
import TypedFlow.Broadcast (permToFun,unopInputShape)
import TypedFlow.Types.Proofs
import TypedFlow.Memo
import Prettyprinter as PP
import Prettyprinter.Render.String as PP
import qualified Data.Map as M
import TypedFlow.Learn
import qualified Data.Sequence as S
import Data.Sequence (Seq, (|>), )
import Data.Foldable

first :: (t -> a) -> (t, b) -> (a, b)
first f (x,y) = (f x,y)

paramShape' :: VarInfo -> [Integer]
paramShape' (VarInfo {varRef = Ref _ s _}) = shapeToList' s

paramDType ::  VarInfo -> Typ
paramDType (VarInfo {varRef = Ref _ _ t}) = sTypTyp t

paramName :: VarInfo -> String
paramName (VarInfo {varRef = Ref {..}}) = refName


generateFile :: String -> Python [VarInfo] -> IO ()
generateFile fname g = do
  putStrLn ("Parameters (total " ++ show (sum [product (paramShape' p) | p <- params]) ++ "):")
  forM_ params printParam
  writeFile fname output
  where (output,params) = generate g
        printParam p = putStrLn (paramName p ++ ": " ++ "T " ++ renderSimple (showShape' (paramShape' p))  ++ " " ++ showT (paramDType p))

named :: String -> DOC -> DOC
named fname x = text (fname <> "=") <> x

text :: String -> DOC
text = pretty

genFun :: forall b. String -> [DOC] -> Python b -> Python b
genFun name args body = do
  gen (text "def " <> text name <> align (tuple args) <> text ":")
  withDOC (\b -> "  " <> align b) body


showTyp :: forall t. KnownTyp t => DOC
showTyp = text (showT (typVal @t))

showSTyp :: forall t. STyp t -> DOC
showSTyp t = knownTyp t $ showTyp @t

showT :: Typ -> [Char]
showT (Typ Bool _) = "tf.bool"
showT (Typ Cmplx B32) = "tf.complex64"
showT (Typ Cmplx B64) = "tf.complex128"
showT (Typ k l) = "tf." ++ map toLower (show k) ++ drop 1 (show l)
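
-- Examples. The last line follows directly from the clauses above; the first
-- two assume that the floating-point and integer element kinds are the
-- constructors Float and Int, as the "tf.float32" style of output suggests.
--
-- > showT (Typ Float B32) == "tf.float32"
-- > showT (Typ Int B64)   == "tf.int64"
-- > showT (Typ Cmplx B64) == "tf.complex128"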

showShape' ::  [Integer] -> DOC
showShape' s = list (map (showDim' "None") s)

showShape :: ∀ (s :: Shape). All KnownNat s => SList s -> DOC
showShape s = showShape' (shapeToList'' s)

showSShape :: ∀ (s :: Shape). SShape s -> DOC
showSShape s = showShape' (shapeToList' s)

showShapeType :: ∀ (s :: Shape). KnownShape s => DOC
showShapeType = showSShape (typeSShape @s)

-- | Show a shape, but "None" is replaced by "-1"
showShapeMinus :: forall (s::Shape) proxy. All KnownNat s => SList' proxy s -> DOC
showShapeMinus s = list (map (showDim' "-1") (shapeToList'' s))

showShapeLen :: ∀ (s::Shape). KnownLen s => DOC
showShapeLen = (text . show) (listTypeLen @s)

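-- | Show a dimension, printing @none@ in place of 514229, the placeholder
-- size that stands for an unknown (None) dimension.
showDim' :: String -> Integer -> DOC
showDim' none n = text (if n == 514229 then none else show n)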

showDimM :: forall n. KnownNat n => DOC
showDimM = showDim' "-1" (natVal (Proxy @n))

showDim :: forall n. KnownNat n => DOC
showDim = showDim' "None" (natVal (Proxy @n))

showDimS :: forall n. Sat KnownNat n -> DOC
showDimS Sat = showDim @n

gen :: DOC -> Python ()
gen s = modify $ \PyState{..} -> PyState {genText=genText |> s,..}

setGen :: Seq DOC -> Python ()
setGen d = modify $ \PyState{..} -> PyState {genText=d,..}

(<--) :: Ref Int s t -> UntypedExpression -> Python ()
x <-- y = gen (pyVarRepr x <> text "=" <>  y)


renderSimple :: Doc ann -> String
renderSimple = renderString . layoutPretty (LayoutOptions Unbounded)

-- | Save an intermediate result to a fresh variable and record it in
-- genAssignTable, so that later occurrences of the same expression re-use it.
cache :: forall s t. KnownTyp t => KnownShape s => DOC -> Python DOC
cache x = do
  let x' = renderSimple x
  mcache <- M.lookup x' <$> gets genAssignTable
  case mcache of
    Just y -> do
      -- comment ("cache hit: " <> text x')
      return y
    Nothing -> do
      -- comment ("cache miss")
      v <- newPyVar @s @t
      comment ("shape: " <> (showShapeType @s))
      v <-- x
      modify $ (\g -> g {genAssignTable = M.insert x' (pyVarRepr v) (genAssignTable g)})
      return (pyVarRepr v)
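cache performs a simple form of common-subexpression elimination. It uses the rendered text of an expression as the lookup key, so two textually identical expressions share one Python variable. The sketch below models only that logic and makes three substitutions: an association list replaces Data.Map, explicit state passing replaces the State monad, and plain strings replace DOC. All names in it are hypothetical. Variable numbering starts at 10000, as it does in initPyState.

```haskell
-- Hypothetical stand-in for the relevant PyState fields.
data CacheState = CacheState
  { nextId      :: Integer             -- counter for fresh variable names
  , assignTable :: [(String, String)]  -- rendered expression -> variable name
  , emitted     :: [String]            -- generated Python lines, in order
  }

-- Numbering starts at 10000, as in initPyState.
emptyCache :: CacheState
emptyCache = CacheState 10000 [] []

-- Return the variable holding expression e. An assignment is emitted
-- only the first time e is seen.
cacheExpr :: String -> CacheState -> (String, CacheState)
cacheExpr e st = case lookup e (assignTable st) of
  Just v -> (v, st)  -- cache hit: nothing new is emitted
  Nothing ->
    let v = "var" ++ show (nextId st)
        st' = st { nextId = nextId st + 1
                 , assignTable = (e, v) : assignTable st
                 , emitted = emitted st ++ [v ++ "=" ++ e]
                 }
    in (v, st')
```

Caching the same expression twice returns `var10000` both times and emits a single assignment line.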

newPyVar' :: forall s t. SShape s -> STyp t -> Python (Ref Int s t)
newPyVar' s t = knownSShape s ?> (knownTyp t $ newPyVar @s @t)

newId :: Python Integer
newId = do
  n <- gets genId
  modify $ \PyState{..} -> PyState {genId=genId+1,..}
  return n

newPyVar :: forall s t. KnownShape s => KnownTyp t => Python (Ref Int s t)
newPyVar = do
  n <- newId
  return $ Ref (fromIntegral n) typeSShape typeSTyp

pyVarInfoRepr :: VarInfo -> DOC
pyVarInfoRepr i = text (varName i)

pyVarRepr :: Ref Int s t -> DOC
pyVarRepr (Ref n _ _) = text ("var" <> show n)

tuple :: [DOC] -> DOC
tuple = parens . align . sep . punctuate comma
dict :: [(String,DOC)] -> DOC
dict xs = braces $ align $ sep $ punctuate comma [text (show k) <> ":" <> v | (k,v) <- xs]

funcall :: String -> [DOC] -> DOC
funcall = funcall' . text

funcall' :: DOC -> [DOC] -> DOC
funcall' f args =  f <> tuple args

comment :: DOC -> Python ()
comment c = gen ("#" <> c)

func :: String -> [DOC] -> [(String,DOC)] -> DOC
func fname positional namedArgs = funcall fname (positional ++ map (uncurry named) namedArgs )

withDOC :: forall a. (DOC -> DOC) -> Python a -> Python a
withDOC f g = do
  before <- gets genText
  setGen mempty
  x <- g
  after <- gets genText
  setGen (before |> f (vcat $ toList after))
  return x

generate :: Python [VarInfo] -> (String,[VarInfo])
generate s = (renderString (layoutPretty (LayoutOptions (AvailablePerLine 92 1)) (vcat $ toList genText)),
              genPyVars)
  where (genPyVars,PyState{..}) = runState s initPyState
        initPyState = PyState {genPureTable = mempty
                              ,genAssignTable = mempty
                              ,genText = mempty
                              ,genId = 10000}

generatePure :: forall s t. KnownTyp t => KnownShape s => T s t -> Python DOC
generatePure x = do
  let sn = makeSn2 x
  mv <- snMapLookup2 sn <$> gets genPureTable
  case mv of
    Just v -> do
        -- comment ("gp hit:" <> v)
        return v
    Nothing -> do
      -- comment ("gp miss")
      e <- generatePure' (\s x' -> knownSShape s ?> generatePure x') typeSShape x
      v <- cache @s @t e
      modify (\g -> g {genPureTable = (snMapInsert2 sn v) (genPureTable g)})
      return v

genDistr :: forall s s0 t. KnownTyp t => Distribution s t -> SShape s0 -> SShape s -> DOC
genDistr d sh s1 = case d of
  TruncatedNormalD stddev -> funcall "tf.random.truncated_normal"
    [showSShape (sh .+. s1), named "stddev" (float stddev), named "dtype" (showTyp @t)]
  UniformD low high -> funcall "tf.random.uniform" [showSShape (sh .+. s1)
                                ,named "minval" (float low)
                                ,named "maxval" (float high)
                                ,named "dtype" (showTyp @t)]
  OrthogonalD ->
    funcall' (funcall "tf.keras.initializers.orthogonal" []) [named "dtype" (showTyp @t), named "shape" (showSShape (sh .+. s1))]

generatePure' :: forall s t. KnownTyp t => (forall s' t'. KnownTyp t' => SShape s' -> T s' t' -> Python DOC) -> SShape s -> T s t -> Python DOC
generatePure' rec sR = knownSShape sR ?> \case
  Unbroadcast{} -> error "broadcasting operation did not complete (Unbroadcast)!"
  BroadcastT _ _ _ sh x -> --- error "broadcasting operation did not complete (BroadcastT)!"
    do
     -- debug help
     rx <- rec sh x
     return (funcall "ERROR:BroadcastT" [rx])
  MapT {} -> error "broadcasting operation did not complete (mapT)!"
  ZipT {} -> error "broadcasting operation did not complete (ZipT)!"
  Zip3T {} -> error "broadcasting operation did not complete (Zip3T)!"
  If c x y -> do
    rc <- rec typeSShape c
    rx <- rec typeSShape x
    ry <- rec typeSShape y
    return (func "tf.cond" [rc] [("true_fn", lambda0 rx) ,("false_fn", lambda0 ry)])
    where lambda0 z = text "lambda: " <> z
  -- if broadcast_to is broken: https://github.com/tensorflow/tensorflow/issues/21901
  -- DirectBroadcast s0 s1 s2 s3 x -> do
  --  recx <- rec (s0 .+. s2) x
  --  let expanded = func "tf.reshape" [recx,list (map (showDim' "-1")
  --         (concat [shapeToList' s0, genericReplicate (sListLength s1) 1
  --                 ,shapeToList' s2, genericReplicate (sListLength s3) 1 ]))] []
  --  return (funcall "tf.add" [expanded, func "tf.zeros" [showSShape sR] [("dtype", showTyp @t)]])
  DirectBroadcast s0 s1 s2 s3 x -> do
   recx <- rec (s0 .+. s2) x
   let expanded = func "tf.reshape" [recx,list (map (showDim' "-1")
          (concat [shapeToList' s0, genericReplicate (sListLength s1) 1
                  ,shapeToList' s2, genericReplicate (sListLength s3) 1 ]))] []
   return (funcall "tf.broadcast_to" [expanded, showSShape sR])
  Noise noiseId s0 s1 x -> do
    return $ (genDistr x s0 s1) <+> (text "# " <> integer noiseId)
  T op -> return $ case op of
    ExternalVar (Ref v _ _) -> text v
    Variable v -> pyVarRepr v
    (Constant c) -> funcall "tf.constant" [prettyT @t c, named "shape" (showSShape sR), named "dtype" (showTyp @t)]
    (Range n@Sat) -> (func "tf.range" [] [("start",integer 0),
                               ("limit",integer (natVal n)),
                               ("dtype",showTyp @t)])
  Where c x y -> do
    rc <- rec typeSShape c
    rx <- rec typeSShape x
    ry <- rec typeSShape y
    return (funcall "tf.where" [rc, rx, ry])
  UnOp operation s0  x -> do
   recx <- rec (s0 .+. unopInputShape operation) x
   return $ case operation of
    Diag _ -> funcall "tf.linalg.diag" [recx]
    Cast -> funcall "tf.cast" [recx,showTyp @t]
    StopGradient -> funcall "tf.stop_gradient" [recx]
    ExpM _  -> funcall "tf.linalg.expm" [recx]
    ZeroTriangle _ side k  -> funcall ("tf.experimental.numpy.tri" ++ case side of Upper -> "u"; Lower -> "l") [recx, integer k]
    Conjugate -> funcall "tf.math.conj" [recx]
    RealPart -> funcall "tf.math.real" [recx]
    Axis1Op _ (SliceOp _ _ lo hi) -> recx <> list (replicate (fromIntegral (sListLength s0)) (text ":") ++ [integer lo <> text ":" <> integer hi])
    Axis1Op _ (AccessOp _ idx) -> recx <> list (replicate (fromIntegral (sListLength s0)) (text ":") ++ [integer idx])
    Axis1Op _ op' ->
       let (op,args) = case op' of
                         SliceOp {} -> error "Python: panic: sliceop is special"
                         AccessOp {} -> error "Python: panic: accessop is special"
                         ReverseT _ -> ("tf.reverse",[])
                         OneHot depth -> ("tf.one_hot",[("dtype",showTyp @t), ("depth", showDimS depth)])
                         ArgMax{} -> ("tf.argmax",[("output_type",showTyp @t)])
                         ReduceOp _ r -> ("tf.reduce_" ++ rop, [])
                            where rop = case r of
                                           Max -> "max"
                                           Min -> "min"
                                           Sum -> "sum"
                                           Mean -> "mean"
           axisName = if op == "tf.nn.softmax" then "dim" else "axis"  -- use dim before TF 1.5
           useAxisList = case op' of ReverseT _ -> True; _ -> False
       in func op [recx] ((axisName,(if useAxisList then (list . (:[])) else id) (integer (sListLength s0))):args)
    Float1Op op' -> funcall op (recx:args)
       where (op,args) = case op' of
                HardSigmoid -> ("tf.keras.backend.hard_sigmoid",[])
                Relu -> ("tf.nn.relu",[])
                ClipByValue lo hi -> ("tf.clip_by_value",[float lo,float hi])
                _ -> ("tf." ++ map toLower (show op'), [])
    Num1Op op' -> funcall op (recx:args)
       where (op,args) = case op' of
                Negate -> ("tf.negative",[])
                _ -> ("tf." ++ map toLower (show op'), [])
  MatMul s0 a b c x y  -> do
    recx <- rec (s0 .+. (:*) a ((:*) b Unit)) x
    recy <- rec (s0 .+. (:*) b ((:*) c Unit)) y
    return (funcall "tf.matmul" [recx, recy])
  BinOp operation s0 s1 _ s2 _ x y -> do
   recx <- rec (s0 .+. s1) x
   recy <- rec (s0 .+. s2) y
   return $ case operation of
     Simple2Op sop  -> let pop = case sop of
                                   MkComplex -> "tf.complex"
                                   Add -> "tf.add"
                                   Divide -> "tf.divide"
                                   IntegerDiv -> "tf.math.floordiv"
                                   Equal -> "tf.equal"
                                   Subtract -> "tf.subtract"
                                   Multiply -> "tf.multiply"
                                   Minimum -> "tf.minimum"
                                   Maximum -> "tf.maximum"
                                   Comparision op -> "tf.math." ++ case op of
                                     Less -> "less"
                                     Greater -> "greater"
                                     LessOrEqual -> "less_equal"
                                     GreaterOrEqual -> "greater_equal"
                                   Logic op -> "tf.math.logical_" ++ case op of
                                      And -> "and"
                                      Or -> "or"
                                   FloorMod -> "tf.math.floormod"
                       in funcall pop [recx,recy]
     SigmoidCrossEntropyWithLogits -> func "tf.nn.sigmoid_cross_entropy_with_logits" [] [("labels",recx),("logits",recy)]
     SparseSoftmaxCrossEntropyWithLogits -> func "tf.nn.sparse_softmax_cross_entropy_with_logits" []  [("labels",recx),("logits",recy)]
     SoftmaxCrossEntropyWithLogits -> func "tf.nn.softmax_cross_entropy_with_logits" []   [("labels",recx),("logits",recy)] -- FIXME: use _v2 for TF 1.5
  ReshapeFrom s t -> do
    rt <- rec s t
    return (funcall "tf.reshape" [rt, showShapeMinus sR])
  Concat s0 s1 xs -> do
    let go :: forall s0 s1 ns. SShape s0 -> SShape s1 -> NP (Catable s0 s1 t) ns -> Python [DOC]
        go _ _ Unit = return []
        go s0' s1' (Catable n y :* ys) = (:) <$> rec (s0' .+. n :* s1') y <*> go s0' s1' ys
    rxs <- go s0 s1 xs
    return (funcall "tf.concat" [list rxs, text "axis=" <> integer (sListLength s0)])
  Transpose s p x -> do
    rx <- rec s x
    comment ("transpose: p = " <> text (show p) <> "; " <> text (show s))
    return (func "tf.transpose" [rx] [("perm",list (map (integer . permToFun p) [0.. sListLength s-1]))])
  Gather indexShape s0 m s1 x ix -> do
    rx <- rec (s0 .+. ((:*) m s1)) x
    rix <- rec (s0 .+. indexShape) ix
    return (func "tf.gather" [named "params" rx, named "indices" rix, named "batch_dims" (integer (sListLength s0)), named "axis" (integer (sListLength s0))] [])
  GatherND containerShape elementShape indexShape x ix -> do
    rx <- rec (containerShape .+. elementShape) x
    rix <- rec (indexShape *: (sListLenAsNat containerShape)) ix
    return (func "tf.gather_nd" [rx, rix] [])
  Convolution bs inChans outChans filterShape s0 x filters -> do
    recx <- rec ((:*) bs (s0 *: inChans)) x
    recFilters <- rec (filterShape .+. ((:*) inChans ((:*) outChans Unit))) filters
    return (func "tf.nn.convolution" [recx, recFilters] [("padding",text (show ("SAME"::String))),("data_format", text (show dataFormat))])
   where dataFormat = case sListLength filterShape of
           1 -> ("NWC" :: String)
           2 -> "NHWC"
           3 -> "NDHWC"
           _ -> error "convolution: more than 3 spatial dimensions are not supported!"
  Pool bs window typ numChans outSpatial x -> do
     rx <- rec ((:*) bs (zipWithMulSShapes window outSpatial .+. (:*) numChans Unit)) x
     return (func "tf.nn.pool"
                  [rx, showSShape window, typ']
                  [("strides", showSShape window),
                   ("padding",text (show ("SAME" :: String)))])
   where typ' = text $ (show $ case typ of MaxPool -> "MAX"; AvgPool -> "AVG" :: String)
  Softmax _ _ x -> do
     rx <- rec typeSShape x
     return $ func "tf.nn.softmax" [rx] [("axis","1")]
  -- _ -> error "Python compiler: case not covered"
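Two pieces of shape bookkeeping in generatePure' are easy to misread.

* DirectBroadcast first reshapes its input so that a singleton axis stands in for every dimension of s1 and s3. It then lets tf.broadcast_to expand those axes.
* The axis-wise slice case emits one `:` per dimension of the prefix s0 before the actual `lo:hi` range, so the slice applies to axis number `length s0`.

The list-level sketch below assumes that the input of DirectBroadcast has shape `s0 ++ s2` and that its result has shape `s0 ++ s1 ++ s2 ++ s3`, which is what the reshape in the code implies. The function names are hypothetical.

```haskell
import Data.List (intercalate)

-- Shapes for a DirectBroadcast node, assuming the input has shape
-- s0 ++ s2 and the result has shape s0 ++ s1 ++ s2 ++ s3.
directBroadcastShapes :: [Integer] -> [Integer] -> [Integer] -> [Integer]
                      -> ([Integer], [Integer])
directBroadcastShapes s0 s1 s2 s3 =
  ( s0 ++ map (const 1) s1 ++ s2 ++ map (const 1) s3  -- argument to tf.reshape
  , s0 ++ s1 ++ s2 ++ s3 )                             -- argument to tf.broadcast_to

-- Python slicing syntax for a slice along the axis that follows
-- 'axis' untouched leading dimensions.
renderSlice :: Int -> String -> Integer -> Integer -> String
renderSlice axis x lo hi =
  x ++ "[" ++ intercalate ", " (replicate axis ":" ++ [show lo ++ ":" ++ show hi]) ++ "]"
```

For instance, `renderSlice 2 "var10003" 2 5` gives `"var10003[:, :, 2:5]"`. In the real code, the reshape target uses showDim' "-1", so an unknown dimension there becomes -1.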
type Python a = State PyState a

generateParameters :: [VarInfo] -> Python [DOC]
generateParameters genVars = do
  -- generate variables
  forM genVars $ \v -> case v of
      VarInfo {..} -> case varRef of
        Ref refId shap typ -> do
          ii <- case varInitial of
            Nothing -> return []
            Just iii -> do
              iiii <- case knownSShape shap of
                Sat -> knownTyp typ $ generatePure iii
              return [named "initial_value" iiii]
          var <- newPyVar' shap typ
          var <-- funcall "tf.Variable" ([named "name" (string refId), named "trainable" (bool varTrainable)] ++ ii)
          return (pyVarRepr var)

-- | Clip a gradient
clipByGlobalNorm :: Float -> UntypedExpression -> UntypedExpression
clipByGlobalNorm maxNorm x = funcall "tf.clip_by_global_norm" [x,float maxNorm] <> brackets (int 0)
 -- clip_by_global_norm returns a pair (clipped gradients, global norm); keep the clipped gradients.

-- | Gradient of @y@ with respect to the given parameters @vars@.
grad :: UntypedExpression -> UntypedExpression -> UntypedExpression
grad y vars = funcall "tf.gradients" [y, vars]


fnToPython ::[VarInfo] -> PreparedFunction -> Python ()
fnToPython params PreparedFunction{pfInputs = SomeSuch placeHolders,
                                   pfOutputs = SomeSuch returned,..} = do 
  -- we can't re-use intermediate computations from initialisers or other functions:
  modify $ \PyState {..} -> PyState {genPureTable = mempty, genAssignTable = M.empty,..}
  gen (text "@tf.function")
  genFun (pfName <> "_fn") (text "training_placeholder":
                  map pyVarInfoRepr params ++
                  hMapToList @KnownPlaceholder (text . placeholderName) placeHolders) $
    do returns <- hfor @KnownPlaceholder returned $ \ph@(PHT x) -> do
         r <- generatePure x
         return (placeholderName ph,r)
       gen (text "return " <> dict returns)
       return ()
  gen (text pfName <> " = " <>
        dict [
          ("function",text pfName <> "_fn"),
          ("batched",bool pfBatched),
          ("placeholders",dict (hMapToList @KnownPlaceholder
        (\ph -> case placeHolderRef ph of
                  Ref nm shape typ ->
                    (nm, dict [("shape",showSShape shape), ("dtype",showSTyp typ)]))
        placeHolders))])
  return ()
  
toPython :: PreparedModel -> Python ()
toPython PreparedModel {..} = do
  gen (text "import tensorflow as tf")
  -- Static stuff: construct and initialise parameters, list placeholders, etc.
  genFun "mkModel" [] $ do
    vs <- generateParameters pmParams
    gen (text "return " <>
         dict [("batch_size", integer pmBatchSize)
              ,("parameters",list vs)
              ,("paramsdict",dict [(varName p, v) | (p,v) <- zip pmParams vs])])
  -- Model functions (e.g. loss, accuracy, prediction)
  forM_ pmFunctions (fnToPython pmParams)
  return ()

-- | Batchify and compile a model with simple input to output mapping.
compile :: forall batchSize sx tx sy ty sy_ ty_ p
        .  (KnownNat batchSize, KnownShape sx, KnownTyp tx, KnownShape sy, KnownTyp ty, KnownShape sy_, KnownTyp ty_, KnownShape p, KnownLen p)
        => Options
        -> Gen (Tensor sx tx -> Tensor sy ty -> ModelOutput  ty_ p sy_)
        -> Python [VarInfo]
compile options fGen = knownSShape (typeSShape @sy_ .+. typeSShape @p) ?> compileGen @batchSize options (sequenceA [simpleModel @p <$> fGen])

-- | Batchify and compile a model with a generic input-to-output mapping and state.
compileGen :: forall bs. (KnownNat bs)
           => Options
           -> Gen [Function]
           -> Python [VarInfo]
compileGen options model = toPython pm >> return pmParams
  where pm@PreparedModel{..} = prepare @bs model



prettyT :: forall t. KnownTyp t => HaskType t -> DOC
prettyT = case kindVal @(TypKind t) of
  SInt -> case bitsVal @(TypBits t) of
    SB32 -> int . fromIntegral
    SB64 -> int . fromIntegral
  SBool -> bool
  SFloat -> case bitsVal @(TypBits t) of
    SB32 -> float
    SB64 -> double



data PyState = PyState {genId :: Integer
                       ,genText :: S.Seq DOC
                       ,genPureTable :: SSNMap2 Shape Typ T DOC
                       -- ^ Table mapping pointers to their
                       -- interpretations, so that sharing in the data
                       -- structures can be exploited when generating code
                       ,genAssignTable :: M.Map String DOC
                       -- ^ Table mapping expressions to variables, so
                       -- that lost sharing can be recovered
                       -- genPeeks :: [(String,UntypedExpression)]
                       }

type UntypedExpression = DOC
type DOC = Doc ()

double :: Double -> DOC
double = pretty
float :: Float -> DOC
float = pretty
integer :: Integer -> DOC
integer = pretty
int :: Int -> DOC
int = pretty
bool :: Bool -> DOC
bool = pretty
string :: String -> DOC
string = dquotes . text 


================================================
FILE: TypedFlow/TF.hs
================================================
{-# LANGUAGE InstanceSigs #-}
{-|
Module      : TypedFlow.TF
Description : Binding to tensorflow functions
Copyright   : (c) Jean-Philippe Bernardy, 2017
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental

This module provides direct access to the most commonly used
TensorFlow functions. Higher-level functions are not defined here.
-}

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE NoStarIsType #-}

module TypedFlow.TF (
  -- * Variables, Parameters
  -- ** Parameters
  parameter',
  parameter,
  parameterDefault,
  ParamWithDefault(..),
  -- getParameters,
  -- ** Persistent variables
  persistent,
  modifyPersistent,
  -- ** Placeholders and outputs
  -- placeholder,
  -- peekAt,
  -- peekAtMany,
  -- * Operations
  -- ** Constants
  zeros,
  ones,
  eye,
  constant,
  -- ** indexwise unary operators
  round, sigmoid, relu, floor, square,
  -- ** Indexwise binary operators
  addN, (⊕), (⊝), (⊙), (⊘), equal,
  minT, maxT,
  -- ** Products
  (∙), (·), matmul,
  -- ** Reducers
  reduceMeanAll, reduceSumAll, reduceMinAll, reduceMaxAll,
  reduceSum, reduceMean, reduceMin, reduceMax,
  -- argmax,
  argmax0, argmax1,
  softmax0, softmax1,
  -- ** Gradients
  -- grad,
  -- clipByGlobalNorm,
  clipByValue,
  -- ** Indexing
  last0, nth0, nth0', lookupT, lookupManyT, gather, range, reverseT,
  -- ** Split and concatenate
  slice, slice0, slice1,
  litStack0,
  stack0, unstack0,
  stack1,
  concatT, concat0, concat1,
  consT0, snocT0,
  headT0, tailT0, initT0,
  -- ** Reshaping
  expandDim,
  expandDim0, squeeze0,
  expandDim1, 
  flatten2, flatten3, flatten12, flattenN2,
  inflate2, inflate3, inflate12,
  reshape, flattenAll, inflateAll,
  -- ** Transposition
  transposeN, transposeN', transpose01, transposeN01,
  -- ** Sequences
  sequenceMask,
  -- ** Convolutions
  convolution,
  -- ** Misc
  norm, normalize,
  stopGradient,
  cast,
  oneHot0, oneHot1,
  -- ** Complex numbers
  expm, conjugate, realPart,
  -- ** Triangular and band Matrices
  tril, triu, fillTriangular, fillUpperTriangular,
  -- ** Testing conditions
  if_, where_, lessThan,
  -- * Contrib
  -- ** Mapping
  mapT, zipWithT, zipWith3T,
  mapTT, zipWithTT,
  -- ** Losses
  sigmoidCrossEntropyWithLogits,
  softmaxCrossEntropyWithLogits,
  sparseSoftmaxCrossEntropyWithLogits,
  -- ** Initializers
  noise,
  Distribution(..),
  varianceScaling, glorotUniform,

  -- ** Heterogeneous vectors
  repeatT,

  -- ** Heterogeneous vectors with varying shapes and types
  repeatHT
  ) where

import Prelude hiding (RealFrac(..))
import GHC.TypeLits
import Data.Proxy
import TypedFlow.Types
import TypedFlow.Types.Proofs
import TypedFlow.Abstract
import TypedFlow.Broadcast

-- | Repeat a shape-polymorphic constant tensor to form a heterogeneous vector of tensors, one per shape in @ss@.
repeatT :: forall (ss :: [Shape]) t. All KnownShape ss => KnownLen ss =>
           (forall s. KnownShape s => T s t) -> HTV t ss
repeatT f = zs (typeSList @ss)
  where zs :: forall (s :: [Shape]). All KnownShape s => SList s -> HTV t s
        zs Unit = Unit
        zs (_ :* n) = F f :* zs n

-- | Repeat a shape- and type-polymorphic constant tensor to form a vector of tensors, one per (shape, type) pair in @ss@.
repeatHT :: forall ss. All KnownPair ss => KnownLen ss =>
           (forall s t. KnownShape s => KnownTyp t => T s t) -> HHTV ss
repeatHT f = zs (typeSList @ss)
  where zs :: forall s. All KnownPair s => SList s -> HHTV s
        zs Unit = Unit
        zs (_ :* n) = Uncurry f :* zs n

-- | Declare a parameter to optimize.
parameter' :: ∀ (shape :: Shape) t. (KnownTyp t,KnownShape shape) => String -> T shape t -> Gen (T shape t)
parameter' = persistent True

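-- | Create a (possibly compound) parameter: every tensor produced by the initializer is declared as a trainable parameter whose name is derived from the given string.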
parameter :: forall p. KnownTensors p => String -> Gen p -> Gen p
parameter s p = travTensor parameter' s =<< p

-- | Declare a variable which persists between calls to the generated functions. The 'Bool' argument states whether the variable is trainable.
persistent :: ∀ (shape :: Shape) t. (KnownTyp t,KnownShape shape) => Bool -> String -> T shape t -> Gen (T shape t)
persistent trainable name initial = do
  T . ExternalVar <$> GPVariable trainable name (Just initial)


-- | Modify a mutable tensor. Attention: for the assignment to happen,
-- the resulting tensor must be evaluated!
modifyPersistent :: (KnownShape s,KnownTyp t) => T s t -> T s t -> Gen (T s t)
modifyPersistent (T (Variable v)) x = GPModify v x -- FIXME: pattern matching here is poor style.

-- type family AddSpatialDims xs ys where
--   AddSpatialDims '[x] '[] = '[x]
--   AddSpatialDims (x ': xs) (y ': ys) = (x+(y-1)) ': AddSpatialDims xs ys

-- -- | Convolution operation with no padding (applying the filter only on positions where the input is fully defined)
-- convolutionValid :: forall outputChannels filterSpatialShape inChannels s t.
--                KnownLen filterSpatialShape
--             => Length filterSpatialShape <= 3
--             => ((1 + Length filterSpatialShape) ~ Length s) -- the last dim of s is the batch size
--             => T (inChannels ': AddSpatialDims s filterSpatialShape) t -- ^ input tensor (batched)
--             -> T ('[outputChannels,inChannels] ++ filterSpatialShape) t -- ^ filters
--             -> T (outputChannels ': s) t
-- convolutionValid = untypedConvolution "VALID"

-- poolNC :: forall dim s inputSpatialShape channels batchSize t.
--                   (inputSpatialShape ~ Take dim s, '[batchSize] ~ Drop dim s) =>
--                   T ('[channels] ++ s) t ->
--                   Vec dim  -> String -> String -> 
--                   T ('[channels] ++ s) t
-- poolNC (T input) windowShape poolingType padding =
--    T (funcall "tf.nn.pool" [input,list (map float (vecToList windowShape)),text poolingType,text padding,named "data_format" (text "NWC")])

-- Difficulty: relate windowSize, inputSpatialShape, outputSpatialShape




---------------------------
-- Contrib
data VarianceScaleMode = VSFanIn | VSFanOut | VSAvg
data Distrib = NormalDistr | UniformDistr

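-- | Random tensor with variance scaling, in the style of Keras' @VarianceScaling@ initializer: the variance is roughly @factor@ divided by the fan-in, the fan-out, or their average.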
varianceScaling :: forall inDim outDim t. KnownNat inDim => (KnownNat outDim, KnownFloat t) =>
   Float -> VarianceScaleMode -> Distrib -> Gen (Tensor '[inDim,outDim] t)
varianceScaling factor mode distr = noise $ case distr of
                                   UniformDistr -> UniformD (-limit) limit
                                   NormalDistr -> TruncatedNormalD limit
  where
    fan_in = fromIntegral (natVal (Proxy @inDim))
    fan_out = fromIntegral (natVal (Proxy @outDim))
    n = max 1 $ case mode of
                  VSFanIn -> fan_in
                  VSFanOut -> fan_out
                  VSAvg -> (fan_in + fan_out) / 2
    limit = sqrt ((case distr of NormalDistr -> 1.3; UniformDistr -> 3) * factor / n)
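The limit computed by varianceScaling follows from the variance of the chosen distribution. A uniform distribution on [-limit, limit] has variance limit²/3, so setting limit = sqrt(3 · factor / n) gives a variance of factor / n. For the truncated normal, the constant 1.3 enlarges the standard deviation, presumably to compensate for the variance lost to truncation. The sketch below reproduces the arithmetic on plain numbers. Its type and constructor names are hypothetical stand-ins for VarianceScaleMode and Distrib.

```haskell
-- Hypothetical stand-ins for VarianceScaleMode and Distrib.
data ScaleMode = FanIn | FanOut | FanAvg
data ScaleDistr = NormalScale | UniformScale

-- Scale parameter for a weight matrix with the given fan-in and fan-out.
-- It is the bound of the uniform distribution, or the standard deviation
-- passed to the truncated normal.
scaleLimit :: Float -> ScaleMode -> ScaleDistr -> Integer -> Integer -> Float
scaleLimit factor mode distr fanIn fanOut = sqrt (c * factor / n)
  where
    fi = fromIntegral fanIn
    fo = fromIntegral fanOut
    n = max 1 $ case mode of
      FanIn -> fi
      FanOut -> fo
      FanAvg -> (fi + fo) / 2
    c = case distr of
      NormalScale -> 1.3
      UniformScale -> 3
```

glorotUniform uses a factor of 1, averaged fans and a uniform distribution. With those settings, a 100×300 layer gets limit = sqrt(3/200) ≈ 0.1225.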


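-- | Glorot (Xavier) uniform initializer: variance scaling with factor 1 over the average of fan-in and fan-out, using a uniform distribution.
glorotUniform :: forall inDim outDim t. KnownNat inDim => (KnownNat outDim, KnownBits t) => Gen (Tensor '[outDim,inDim] ('Typ 'Float t))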
glorotUniform = varianceScaling 1 VSAvg UniformDistr

-- | 'cons' an element and an array (in the first dimension)
consT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n =>  T s t -> T (n ': s) t -> T (n+1 ': s) t
consT0 x xs = plusComm @1 @n #> concat0 (expandDim0 x) xs

-- | 'snoc' an element and an array (in the first dimension)
snocT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n =>  KnownLen s => T (n ': s) t -> T s t -> T (n+1 ': s) t
snocT0 xs x = concat0 xs (expandDim0 x)

headT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n =>  T (n+1 ': s) t -> T (s) t
headT0 xs = nth0 0 xs

tailT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n =>  T (n+1 ': s) t -> T (n ': s) t
tailT0 xs = incrPos @n              #> -- 0 < n+1
            plusMinusAssoc @n @1 @1 #> -- (n+1) - 1 = n + (1 - 1)
            slice0 @1 @(n+1) xs

initT0 :: forall n s t. KnownTyp t => KnownShape s => KnownNat n =>  T (n+1 ': s) t -> T (n ': s) t
initT0 xs = plusMono @n @1 #> -- n <= n+1
            slice0 @0 @n xs

----------------
-- Helpers

-- | Product of a matrix of weights with a vector.
(∙) :: (KnownNumeric t, KnownNat cols, KnownNat rows, KnownTyp t) => Tensor '[cols, rows] t -> Tensor '[cols] t -> Tensor '[rows] t
m ∙ v = squeeze0 (matmul (expandDim0 v) m)
infixl 7 ∙

-- | Dot product between two vectors.
(·) :: ∀ n t. (KnownNumeric t, KnownNat n) =>
  Tensor '[n] t -> Tensor '[n] t -> Tensor '[] t
x · y = reduceSum0 (x ⊙ y)
infixl 7 ·

-- | 2-Norm of a vector
norm :: KnownBits t => KnownNat n
     => T '[n] (Flt t) -> Scalar (Flt t)
norm = frobNorm

-- | 2-Norm of a tensor
frobNorm :: KnownShape s => KnownBits t => T s (Flt t) -> Scalar (Flt t)
frobNorm = sqrt . reduceSumAll . square

normalize :: (KnownNat n, KnownBits t) =>
                   T '[n] (Flt t) -> T '[n] (Flt t)
normalize v = mapT (/ (norm v + epsilon)) v
  where epsilon = 1.0e-8

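-- | Build an n×n triangular matrix from a vector of l = n(n+1)/2 elements; the constraint @((l+l)-n) ~ (n*n)@ encodes this relation.
fillTriangular :: forall n l t.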
                  (KnownNat n, KnownNat l, KnownNumeric t, (((l+l)-n) ~ (n*n)), n <= l)
               => Tensor '[l] t -> Tensor '[n,n] t
fillTriangular x = plusMinusAssoc @l @l @n #> tril 0 (inflate2 (concat0 x rr))
  where rr :: Tensor '[l - n] t
        rr = subIneq @l @n #> slice0 @0 @(l-n) (reverseT x) 


-- | @lookupManyT def indices array@ looks up each of @indices@ in @array@, returning @def@ wherever the index is -1.
lookupManyT :: forall s n t. KnownNat n => KnownShape s => (KnownNumeric t) => Scalar t -> T s Int32 -> T '[n] t -> T s t
lookupManyT def indices array =
  appRUnit @s #> mapTT @s (\idx -> where_ (equal idx (-1)) def (lookupT idx array)) indices
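The semantics of lookupManyT can be stated on plain lists: each index selects an element of the array, and the index -1 selects the default instead. The sketch below is a reference model for reasoning and testing only. lookupMany is a hypothetical name, and the model does not mirror how the TensorFlow graph is built.

```haskell
-- Hypothetical list model of lookupManyT: index -1 yields the default.
lookupMany :: a -> [Int] -> [a] -> [a]
lookupMany def indices array = map pick indices
  where
    pick (-1) = def
    pick i = array !! i
```

For example, `lookupMany 0 [2, -1, 0] [10, 20, 30]` is `[30, 0, 10]`.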


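-- | Fill the strictly upper triangle (above the diagonal) of an n×n matrix, row by row, with the elements of the input vector of length l; entries with no corresponding element, and all entries on or below the diagonal, are zero.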
fillUpperTriangular :: forall n l t. KnownNumeric t => KnownNat n => KnownNat l => T '[l] t -> T '[n,n] t
fillUpperTriangular x =
  zipWithTT @'[n,n]
  (\i j -> let idx :: Scalar Int32
               idx = ((i * (2 * n - i - 3)) `floorDiv` 2 + j - 1)

-- The index to look up in the input array is computed from the formula:
--   idx(i,j) = (j-i-1) + ∑_{k=0}^{i-1} (n-1-k)
--            = i*(2n-i-3)/2 + j - 1
-- The term j-i-1 is the offset of (i,j) from the first entry above the diagonal in row i.
-- The sum is the number of strictly upper-triangular elements in the previous rows.
           in where_ (((j - i) `greaterThan` 0) `logicAnd` (idx `lessThan` l))
                     (lookupT idx x)
                     zeros)
    range0 
    range1 where

  n, l :: Scalar Int32
  n = constant (fromIntegral (natVal (Proxy @n)))
  l = constant (fromIntegral (natVal (Proxy @l)))
  
  -- "j" index
  range1 :: forall n m w. (KnownNat n, KnownNat m) => KnownBits w => T '[n,m] ('Typ 'Int w)
  range1 = broadcastT range

  -- "i" index
  range0 :: forall n m w. (KnownNat n, KnownNat m) => KnownBits w => T '[n,m] ('Typ 'Int w)
  range0 = transpose01 range1
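The closed form used for idx in fillUpperTriangular follows from counting entries. Row k of an n×n matrix has n-1-k entries strictly above the diagonal, so rows 0 to i-1 contribute i(2n-i-1)/2 entries. Adding the offset j-i-1 within row i gives i(2n-i-3)/2 + j - 1. The product i(2n-i-3) is always even, so the floor division is exact. The list-based model below can be used to check the formula; upperIndex and fillUpper are hypothetical names.

```haskell
-- Index into the input vector for entry (i, j), as in fillUpperTriangular.
upperIndex :: Int -> Int -> Int -> Int
upperIndex n i j = (i * (2 * n - i - 3)) `div` 2 + j - 1

-- List model: strictly-upper entries are read from xs row by row.
-- Everything else, and anything past the end of xs, is 'zero'.
fillUpper :: Int -> a -> [a] -> [[a]]
fillUpper n zero xs =
  [ [ if j > i && k < l then xs !! k else zero
    | j <- [0 .. n - 1]
    , let k = upperIndex n i j
    ]
  | i <- [0 .. n - 1]
  ]
  where
    l = length xs
```

For n = 3, the vector [1, 2, 3] fills positions (0,1), (0,2) and (1,2) in that order.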


-------------------------
-- Generic parameters

-- | Create a parameter and initialize it with a suitable default for its type. Control the exact initializer using 'parameter'.
parameterDefault :: forall p. ParamWithDefault p => String -> Gen p
parameterDefault name = parameter name defaultInitializer


-- flattenHTV :: KnownTyp t => All KnownShape xs => HTV t xs -> Tensor '[Sum (Ap (FMap CProduct) xs)] t
-- flattenHTV Unit = zeros
-- flattenHTV (F x :* xs) = concat0 (flattenAll x) (flattenHTV xs)

-- class CProduct (xs :: [Nat])
-- instance Fun CProduct where type Ap CProduct xs = Product xs

-- inflateHTV :: ∀ xs s t. (All KnownShape xs, KnownLen s, KnownLen xs) =>
--           Tensor '[Sum (Ap (FMap CProduct) xs)] t -> Gen (HTV t xs)
-- inflateHTV (T x) = do
--   v <- newVar
--   gen (v <> text " = " <> funcall "tf.split" [x, showShape' (prodshape @xs shapeSList), text "axis=0"])
--   return (mkArr @xs 0 shapeSList  v)
--   where mkArr :: forall zs. All KnownShape zs => Int -> SList zs -> DOC -> HTV t zs
--         mkArr _ LZ _ = Unit
--         mkArr i (LS _ n) v = F (unsafeReshape (T (v <> brackets (int i)) )):* mkArr (succ i) n v
--         prodshape :: forall zs. All KnownShape zs => SList zs -> [Integer]
--         prodshape LZ = []
--         prodshape (LS xx xs) = product (shapeToList' (shapeSListProxy xx)) : prodshape xs


-- -- | Gradient of wrt. given parameters.
-- grad' :: KnownLen xs => T s Float32 -> HHTV xs -> Gen (HHTV xs)
-- grad' (T y) vars = do
--  v <- newVar
--  v <-- funcall "tf.gradients" [y, list (htoList (hmap (\(Uncurry (T x)) -> K x) vars)) ]
--  return (mkArr 0 shapeSList v)
--   where mkArr :: forall xs. Int -> SList xs -> DOC -> HHTV xs
--         mkArr _ LZ _ = Unit
--         mkArr i (LS _ n) v = Uncurry (T (v <> brackets (int i))) :* mkArr (succ i) n v


================================================
FILE: TypedFlow/Types/Proofs.hs
================================================
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE CPP #-}
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE NoStarIsType #-}
#endif
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}

module TypedFlow.Types.Proofs where


import Prelude hiding (RealFrac(..))
import GHC.TypeLits
import Data.Proxy
import TypedFlow.Types hiding (T)
import Data.Type.Equality
import Unsafe.Coerce
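import Data.Kind (Type)

-- | Singletons whose equality can be tested at runtime, yielding a
-- proof of type equality on success.
class SingEq s where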
  testEq :: forall a b. s a -> s b -> Maybe (a :~: b)

instance SingEq (Sat KnownNat) where
  testEq :: forall n m. Sat KnownNat n -> Sat KnownNat m -> Maybe (n :~: m)
  testEq = testNatEqual

natValS :: forall m. Sat KnownNat m -> Integer
natValS Sat = natVal (Proxy @m)

testNatEqual :: Sat KnownNat m -> Sat KnownNat n -> Maybe (m :~: n)
testNatEqual m n = if natValS m == natValS n then Just (unsafeCoerce Refl) else Nothing
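
-- Example (a sketch, not part of the original module; 'testNatEqual'' is a
-- hypothetical name): assuming 'Sat' is the constraint-carrying GADT from
-- TypedFlow.Types, so that matching on 'Sat' brings @KnownNat m@ into scope
-- (as 'natValS' relies on), the same test can be written with 'sameNat' from
-- GHC.TypeLits, which avoids the local 'unsafeCoerce':
--
-- > testNatEqual' :: Sat KnownNat m -> Sat KnownNat n -> Maybe (m :~: n)
-- > testNatEqual' m@Sat n@Sat = sameNat m n
--
-- In GHCi, @testNatEqual (Sat :: Sat KnownNat 3) (Sat :: Sat KnownNat 3)@
-- evaluates to @Just Refl@, while comparing 3 with 4 gives @Nothing@.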

instance SingEq f => SingEq (NP f) where
  testEq Unit Unit = Just Refl
  testEq (x :* xs) (y :* ys) = case (testEq x y, testEq xs ys) of
    (Just Refl, Just Refl) -> Just Refl
    _ -> Nothing
  testEq _ _ = Nothing

instance SingEq SKind where
  testEq SBool SBool = Just Refl
  testEq SInt SInt = Just Refl
  testEq SFloat SFloat = Just Refl
  testEq _ _ = Nothing

instance SingEq SNBits where
  testEq SB32 SB32 = Just Refl
  testEq SB64 SB64 = Just Refl
  testEq _ _ = Nothing

instance SingEq STyp where
  testEq (STyp k b Refl) (STyp k' b' Refl) = case (testEq k k', testEq b b') of
    (Just Refl, Just Refl) -> Just Refl
    _ -> Nothing

-- | Use a reified equality relation
(#>) :: (a :~: b) -> ((a ~ b) => k) -> k
Refl #> k = k
infixr 0 #>

-- | Use a reified arbitrary predicate
(?>) :: Sat constraint a -> (constraint a => k) -> k
Sat ?> k = k
infixr 0 ?>

-- | Use a reified arbitrary constraint
(??>) :: Dict constraint -> (constraint => k) -> k
Dict ??> k = k
infixr 0 ??>
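
-- Example (a hypothetical helper, not part of the original module): '(#>)'
-- brings a proven equation into scope while checking its continuation. Here
-- the commutativity lemma 'plusComm' (defined below) lets a proxy indexed by
-- @x + y@ be used at @y + x@:
--
-- > swapPlus :: forall x y. Proxy (x + y) -> Proxy (y + x)
-- > swapPlus p = plusComm @x @y #> p
--
-- Callers must supply @x@ and @y@ by type application, since the module
-- enables AllowAmbiguousTypes and '+' is not injective.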

productS :: forall s. SShape s -> Sat KnownNat (Product s)
productS s = knownSShape s ?>
             knownProduct @s ?>
             Sat

plusComm :: forall x y. (x + y) :~: (y + x)
plusComm = unsafeCoerce Refl

plusCommS :: forall x y px py. px x -> py y -> (x + y) :~: (y + x)
plusCommS _ _ = plusComm @x @y

plusAssoc :: forall x y z. (x + y) + z :~: x + (y + z)
plusAssoc = unsafeCoerce Refl

plusAssocS :: forall x y z px py pz. px x -> py y -> pz z -> ((x + y) + z) :~: (x + (y + z))
plusAssocS _ _ _ = plusAssoc @x @y @z

prodAssoc :: forall x y z. (x * y) * z :~: x * (y * z)
prodAssoc = unsafeCoerce Refl

prodAssocS :: forall x y z px py pz. px x -> py y -> pz z -> ((x * y) * z) :~: (x * (y * z))
prodAssocS _ _ _ = prodAssoc @x @y @z

prodCommS :: forall x y px py. px x -> py y -> (x * y) :~: (y * x)
prodCommS _ _ = unsafeCoerce Refl

termCancelation :: forall a b. (a + b) - b :~: a
termCancelation = plusMinusAssoc @a @b @b #> cancelation @b #> Refl

plusMinusAssoc :: forall x y z. (x + y) - z :~: x + (y - z)
plusMinusAssoc = unsafeCoerce Refl

cancelation :: (a - a) :~: 0
cancelation = unsafeCoerce Refl

plusMono :: forall a b. (a <=? (a+b)) :~: 'True
plusMono = unsafeCoerce Refl

succPos :: (1 <=? 1+j) :~: 'True
  -- CmpNat 0 (1 + n) :~: 'LT
succPos = unsafeCoerce Refl

succPosProx2 :: forall n proxy a. proxy n a -> (0 :<: (1+n))
succPosProx2 _ = succPos @n

prodHomo ::  forall x y. Product (x ++ y) :~: Product x * Product y
prodHomo = unsafeCoerce Refl

prodHomoS ::  forall x y px py. px x -> py y -> ((Product (x ++ y) :~: (Product x * Product y)))
prodHomoS _ _ = prodHomo @x @y

knownProduct' :: forall s f. All KnownNat s => NP f s -> Sat KnownNat (Product s)
knownProduct' Unit = Sat
knownProduct' (_ :* n) = knownProduct' n ?> Sat

knownProduct :: forall s. KnownShape s => Sat KnownNat (Product s)
knownProduct = knownProduct' @s typeSList

knownSumS :: forall s. NP (Sat KnownNat) s -> Sat KnownNat (Sum s)
knownSumS Unit = Sat
knownSumS (Sat :* n) = knownSumS n ?> Sat

knownSum' :: forall s f. All KnownNat s => NP f s -> Sat KnownNat (Sum s)
knownSum' proxies = knownSumS (allKnown' proxies)

knownSum :: forall s. KnownShape s => Sat KnownNat (Sum s)
knownSum = knownSum' @s typeSList

knownPlus :: forall m n. KnownNat m => KnownNat n => Sat KnownNat (m + n)
knownPlus = Sat

takeDrop :: forall s n. (PeanoNat n <= Length s) => (Take n s ++ Drop n s) :~: s
takeDrop = unsafeCoerce Refl

lengthHomo :: forall x y. Length (x ++ y) :~: Length x + Length y
lengthHomo = unsafeCoerce Refl

lengthHomoS :: forall x y proxyx proxyy. proxyx x -> proxyy y -> ((Length (x ++ y) :~: (Length x + Length y)))
lengthHomoS _ _ = lengthHomo @x @y

lengthInit :: forall s. (0 < Length s) => SList s -> ((Length (Init s) + 1) :~: Length s)
lengthInit x = lengthHomo @(Init s) @'[Last s] #> initLast x #> Refl

type a :<=: b = ((a <=? b):~: 'True)
type i :<: j = (i+1) :<=: j
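
-- Example (a sketch, not part of the original module): for literal naturals
-- GHC evaluates '<=?' itself, so both synonyms are inhabited by plain 'Refl':
--
-- > twoLeFive :: 2 :<=: 5
-- > twoLeFive = Refl
-- >
-- > oneLtThree :: 1 :<: 3   -- unfolds to ((1 + 1) <=? 3) :~: 'True
-- > oneLtThree = Refl
--
-- For variables, such facts are instead postulated with 'unsafeCoerce', as in
-- 'incrPos' and 'subIneq'.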

incrPos :: forall x. 1 :<=: (x+1)
incrPos = unsafeCoerce Refl


subIneq :: forall x k. (x - k) :<=: x
subIneq = unsafeCoerce Refl

incrCong :: forall x y. ((x+1) ~ (y+1)) => x :~: y
incrCong = unsafeCoerce Refl

initLast :: forall s. {-(0 < Length s) => FIXME -} SList s -> ((Init s ++ '[Last s]) :~: s)
initLast Unit = error "initLast: does not hold on empty lists"
initLast ((:*) _ Unit) = Refl
initLast ((:*) _ ((:*) y ys)) = initLast ((:*) y ys) #> Refl

initLast' :: forall s. {-(0 < Length s) => FIXME -} KnownShape s => ((Init s ++ '[Last s]) :~: s)
initLast' = initLast (typeSList @s)

appRUnit :: forall s. (s ++ '[]) :~: s
appRUnit = unsafeCoerce Refl

appAssoc ::  ((xs ++ ys) ++ zs) :~: (xs ++ (ys ++ zs))
appAssoc = unsafeCoerce Refl

appAssocS :: forall xs ys zs proxy1 proxy2 proxy3.
             proxy1 xs -> proxy2 ys -> proxy3 zs -> (((xs ++ ys) ++ zs) :~: (xs ++ (ys ++ zs)))
appAssocS _ _ _  = appAssoc @xs @ys @zs


knownLast' :: All KnownNat s => SList s -> (KnownNat (Last s) => k) -> k
knownLast' Unit _ = error "knownLast: does not hold on empty lists"
knownLast' ((:*) _ Unit) k = k
knownLast' ((:*) _ ((:*) y xs)) k = knownLast' ((:*) y xs) k

knownLast :: forall s k. KnownShape s => (KnownNat (Last s) => k) -> k
knownLast = knownLast' @s typeSList

knownInit' :: All KnownNat s => SList s -> Sat KnownShape (Init s)
knownInit' Unit = error "knownInit: does not hold on empty lists"
knownInit' ((:*) _ Unit) = Sat
knownInit' ((:*) _ ((:*) y xs)) = knownInit' ((:*) y xs) ?> Sat

knownInit :: forall s. KnownShape s => Sat KnownShape (Init s)
knownInit = knownInit' @s typeSList

knownTail' :: forall x s k. All KnownNat s => SList (x ': s) -> (KnownShape s => k) -> k
knownTail' ((:*) _ Unit) k = k
knownTail' ((:*) _ ((:*) y xs)) k = knownTail' ((:*) y xs) k

knownTail :: forall s x xs k. (s ~ (x ': xs), KnownShape s) => (KnownShape xs => k) -> k
knownTail = knownTail' @x @xs typeSList

knownAppendS :: forall s t pt. (All KnownNat s, KnownShape t) => SList s -> pt t -> Sat KnownShape (s ++ t)
knownAppendS Unit _t = Sat
knownAppendS ((:*) _ n) t = knownAppendS n t ?> Sat

knownAppend :: forall s t.  (KnownShape s, KnownShape t) => Sat KnownShape (s ++ t)
knownAppend = knownAppendS (typeSList @s) (Proxy @t)


-- knownFmap' :: forall f xs. SList xs -> SList (Ap (FMap f) xs)
-- knownFmap' Unit = Unit
-- knownFmap' ((:*) x n) = (:*) Proxy (knownFmap' @f n)

knownSList :: NP proxy xs -> Sat KnownLen xs
knownSList Unit = Sat
knownSList (_ :* n) = knownSList n ?> Sat

knownSShape :: SShape xs -> Sat KnownShape xs
knownSShape Unit = Sat
knownSShape ((:*) Sat s) = knownSShape s ?> Sat
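
-- Example (a hypothetical helper, not part of the original module): combined
-- with '(?>)', 'knownSShape' turns a runtime shape singleton back into a
-- 'KnownShape' class constraint for the continuation:
--
-- > withKnownShape :: SShape s -> (KnownShape s => r) -> r
-- > withKnownShape s k = knownSShape s ?> k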

data DimExpr (a :: Nat) (x :: Nat) (b :: Nat) where
  ANat :: Sat KnownNat x -> DimExpr a x (a * x)
  (:*:) :: DimExpr a x b -> DimExpr b y c -> DimExpr a (x*y) c

knownOutputDim :: forall a x b. Sat KnownNat a -> DimExpr a x b -> Sat KnownNat b
knownOutputDim a (ANat x) = satMul a x
knownOutputDim a (x :*: y) = knownOutputDim (knownOutputDim a x) y

dimSat :: DimExpr a x b -> Sat KnownNat x
dimSat (ANat s) = s
dimSat (x :*: y) = dimSat x `satMul` dimSat y

normDim :: forall ws xs ys. DimExpr ws xs ys -> (ws * xs) :~: ys
normDim (ANat _) = Refl
normDim (a :*:b) = normDim a #>
                   normDim b #>
                   prodAssocS (Proxy @ws) (dimSat a) (dimSat b) #>
                   Refl
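
-- Example (a sketch, not part of the original module; 'sixDim' is a
-- hypothetical name): a dimension built as the product 2 * 3, starting from
-- an accumulator of 1. GHC reduces the literal arithmetic in the indices
-- (@2 * 3@ and @(1 * 2) * 3@), so both simplify to 6:
--
-- > sixDim :: DimExpr 1 6 6
-- > sixDim = ANat (Sat :: Sat KnownNat 2) :*: ANat (Sat :: Sat KnownNat 3)
--
-- Then @natValS (knownOutputDim (Sat :: Sat KnownNat 1) sixDim)@ is 6, and
-- @normDim sixDim :: (1 * 6) :~: 6@.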

data ShapeExpr (a :: Nat) (x :: Shape) (b :: Nat) where
  Single :: DimExpr a x b -> ShapeExpr a '[x] b
  AShape :: SShape x -> ShapeExpr a x (a * Product x)
  (:++:) :: ShapeExpr a x b -> ShapeExpr b y c -> ShapeExpr a (x++y) c

infixr 5 :++:
infixr 5 *:!
infixr 5 !:*

(!:*) :: DimExpr a x b -> ShapeExpr b xs c -> ShapeExpr a (x ': xs) c
x !:* xs = Single x :++: xs

(*:!) :: ShapeExpr a xs b -> DimExpr b x c -> ShapeExpr a (xs ++ '[x]) c
xs *:! x = xs :++: Single x

exprSShape :: forall a x b. ShapeExpr a x b -> SShape x
exprSShape (AShape s) = s
exprSShape (Single x) = dimSat x ?> typeSShape
exprSShape (x :++: y) = exprSShape x .+. exprSShape y

normShape :: forall ws xs ys. ShapeExpr ws xs ys -> (ws * Product xs) :~: ys
normShape (Single x) = normDim x
normShape (AShape _) = Refl
normShape (l :++: r) = normShape l #>
                       normShape r #>
                       prodHomoS (exprSShape l) (exprSShape r) #>
                       prodAssocS (Proxy @ws) (productS (exprSShape l)) (productS (exprSShape r)) #>
                       Refl
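        -- For l :: ShapeExpr ws x b and r :: ShapeExpr b y ys:
        -- normShape r gives (b * Product y) ~ ys                        (1)
        -- normShape l gives (ws * Product x) ~ b                        (2)
        -- substituting (2) into (1): ((ws * Product x) * Product y) ~ ys
        -- by prodAssoc and prodHomo: (ws * Product (x ++ y)) ~ ys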

decideProductEq1 :: forall xs zs. ShapeExpr 1 xs zs -> Product xs :~: zs
decideProductEq1 a  = case normShape a of Refl -> Refl

type ShapeX = ShapeExpr 1

decideProductEq :: ShapeExpr 1 xs zs -> ShapeExpr 1 ys zs -> Product xs :~: Product ys
decideProductEq l r = case decideProductEq1 l of
                        Refl -> case decideProductEq1 r of
                          Refl -> Refl
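
-- Example (a hypothetical helper, not part of the original module): one
-- plausible use, assumed here, is to justify a reshape between shapes whose
-- products agree by construction, such as merging two dimensions into one.
-- Both expressions below end with the same final index @(1 * m) * n@, so
-- 'decideProductEq' relates their products:
--
-- > mergeDims :: Sat KnownNat m -> Sat KnownNat n
-- >           -> Product '[m * n] :~: Product '[m, n]
-- > mergeDims m n = decideProductEq (Single (ANat m :*: ANat n))
-- >                                 (ANat m !:* Single (ANat n))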


unsafePositive :: (1 <=? n) :~: 'True
unsafePositive = unsafeCoerce Refl
SYMBOL INDEX (35 symbols across 4 files)

FILE: examples/mnist/main.py
  function train_generator (line 17) | def train_generator(batch_size):

FILE: examples/mnist/mnist_model.py
  function mkModel (line 2) | def mkModel():
  function runModel_fn (line 65) | def runModel_fn(training_placeholder,

FILE: examples/seq2seq/main.py
  function pad (line 23) | def pad(ws): return (ws + ' '*(MAXLEN - len(ws)))
  function encode (line 25) | def encode(s):
  function decode (line 29) | def decode(s): return "".join([indices_char[c] for c in list(s)])
  function pad_right (line 31) | def pad_right(sentence): return (MAXLEN - len(sentence)) * " " + sentence
  function pad_left (line 32) | def pad_left(sentence): return  sentence + (MAXLEN - len(sentence)) * " "
  function source_input_conversion (line 34) | def source_input_conversion(s):
  function target_input_conversion (line 37) | def target_input_conversion(sentence):
  function target_output_conversion (line 40) | def target_output_conversion(sentence):
  function sentence_target_weights (line 43) | def sentence_target_weights(sentence):
  function map (line 48) | def map(f,l):
  function make_examples (line 51) | def make_examples(l):
  function s2s_generator (line 59) | def s2s_generator(src_len,src_in,tgt_in,tgt_out,tgt_weights):
  function my_sample (line 71) | def my_sample(l,n):
  function printer (line 85) | def printer(x):
  function translate (line 90) | def translate(s):
  function translate_cb (line 98) | def translate_cb(values):

FILE: typedflow_rts.py
  function cuda_use_device (line 13) | def cuda_use_device(n):
  function find_free_cuda_device (line 19) | def find_free_cuda_device():
  function cuda_use_one_free_device (line 43) | def cuda_use_one_free_device():
  function bilist_generator (line 51) | def bilist_generator(l):
  function bilist_generator_transposed (line 68) | def bilist_generator_transposed(model,l):
  function dict_generator (line 91) | def dict_generator (xs):
  function initialize_params (line 102) | def initialize_params (session,model):
  function train (line 114) | def train (optimizer, model_static, model_fn,
  function StopWhenValidationGetsWorse (line 189) | def StopWhenValidationGetsWorse(patience = 1):
  function StopWhenAccurate (line 206) | def StopWhenAccurate(phase="val",error_rate = .01):
  function Every (line 213) | def Every(n,f):
  function Save (line 223) | def Save(sess,saver,ckptfile):
  function evaluate (line 235) | def evaluate (model_static, model_fn, xs, result="y_"):
  function beam_translate (line 277) | def beam_translate(session, model, k, x, xlen, start_symbol, stop_symbol...
  function save (line 307) | def save(model_static, file):
  function load (line 314) | def load(model_static, file):
