Repository: chrisdone/duet
Branch: master
Commit: db305103f76f
Files: 53
Total size: 236.5 KB

Directory structure:
gitextract_m0lahraa/

├── .gitignore
├── Dockerfile
├── LICENSE.md
├── README.md
├── app/
│   └── Main.hs
├── duet.cabal
├── examples/
│   ├── ack.hs
│   ├── arith.hs
│   ├── bound.hs
│   ├── builtins.hs
│   ├── classes.hs
│   ├── fac.hs
│   ├── factorial.hs
│   ├── folds-strictness.hs
│   ├── folds.hs
│   ├── functor-class.hs
│   ├── gabriel-eq-reason.hs
│   ├── good.hs
│   ├── integers.hs
│   ├── lists.hs
│   ├── monad.hs
│   ├── monoid.hs
│   ├── ord.hs
│   ├── parser.hs
│   ├── pattern-matching.hs
│   ├── placeholders.hs
│   ├── prelude.hs
│   ├── seq.hs
│   ├── sicp.hs
│   ├── simple-class.hs
│   ├── state.hs
│   ├── strict-folds.hs
│   ├── string-pats.hs
│   ├── string-substring.hs
│   ├── syntax-buffet.hs
│   └── terminal.hs
├── src/
│   ├── Control/
│   │   └── Monad/
│   │       └── Supply.hs
│   └── Duet/
│       ├── Context.hs
│       ├── Errors.hs
│       ├── Infer.hs
│       ├── Parser.hs
│       ├── Printer.hs
│       ├── Renamer.hs
│       ├── Resolver.hs
│       ├── Setup.hs
│       ├── Simple.hs
│       ├── Stepper.hs
│       ├── Supply.hs
│       ├── Tokenizer.hs
│       └── Types.hs
├── stack.yaml
└── test/
    ├── Main.hs
    └── Spec.hs

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.stack-work


================================================
FILE: Dockerfile
================================================
FROM frolvlad/alpine-gcc as base

RUN apk add --no-cache ghc curl git

RUN curl -L https://github.com/nh2/stack/releases/download/v1.6.5/stack-prerelease-1.9.0.1-x86_64-unofficial-fully-static-musl > /usr/bin/stack

RUN chmod +x /usr/bin/stack

RUN git clone https://github.com/chrisdone/duet.git --depth 1 && cd duet && git checkout 186d4dbf85f23e28862fce7e8160adddfdb8d36f
RUN cd duet && stack update
RUN apk add --no-cache zlib-dev
RUN cd duet && stack build --system-ghc --dependencies-only

RUN cd duet && git pull && git checkout f6c19caf0cb9182dae665ff47c68c27001763fd9
RUN cd duet && stack install --system-ghc --fast

FROM alpine:3.9
RUN apk add --no-cache gmp libffi

COPY --from=base /root/.local/bin/duet /usr/bin/duet

ENTRYPOINT ["duet"]


================================================
FILE: LICENSE.md
================================================
*Duet* is Copyright (c) Chris Done 2017.

*Typing Haskell in Haskell*, which provides the groundwork for Duet's
type system, is Copyright (c) Mark P Jones and the Oregon Graduate
Institute of Science and Technology, 1999-2000.

All rights reserved, and is distributed as free software under the
following license.

Redistribution and use in source and binary forms, with or
without modification, are permitted provided that the following
conditions are met:

- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

- Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.

- Neither name of the copyright holders nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND THE
CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR THE
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


================================================
FILE: README.md
================================================
# <img src=images/duet.svg height=36> Duet

A tiny language, a subset of Haskell (with type classes), aimed at helping teachers teach Haskell.

## Run

Running code in Duet literally performs one substitution step at a
time. For example, evaluating `(\x -> x + 5) (2 * 3)`, we get:

``` haskell
$ duet run demo.hs
(\x -> x + 5) (2 * 3)
(2 * 3) + 5
6 + 5
11
```

This demonstrates basic argument application and non-strictness: `(2 * 3)` is substituted into the body unevaluated and is only reduced once `+` needs its value.
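
To make the one-step-at-a-time model concrete, here is a minimal sketch in ordinary GHC Haskell of a call-by-name substitution stepper. It is not Duet's implementation: the `Expr` type and the `subst`, `step`, `pretty` and `steps` functions are hypothetical names for illustration, and `subst` assumes variable capture never happens.

```haskell
-- A toy call-by-name stepper in the spirit of `duet run`.
-- Expr is a hypothetical AST for illustration only; it is not Duet's
-- Expression type, and subst assumes no variable capture can occur.
data Expr
  = Var String
  | Lam String Expr
  | App Expr Expr
  | Lit Integer
  | Add Expr Expr
  | Mul Expr Expr
  deriving (Eq, Show)

-- Replace free occurrences of x with arg (naive: no capture avoidance).
subst :: String -> Expr -> Expr -> Expr
subst x arg e = case e of
  Var y
    | y == x -> arg
    | otherwise -> Var y
  Lam y body
    | y == x -> Lam y body -- x is shadowed inside this lambda
    | otherwise -> Lam y (subst x arg body)
  App f a -> App (subst x arg f) (subst x arg a)
  Lit n -> Lit n
  Add a b -> Add (subst x arg a) (subst x arg b)
  Mul a b -> Mul (subst x arg a) (subst x arg b)

-- Perform exactly one rewrite, leftmost-outermost first.
-- Arguments are substituted unevaluated (non-strict, no sharing).
step :: Expr -> Maybe Expr
step e = case e of
  App (Lam x body) arg -> Just (subst x arg body)
  App f a -> stepPair App f a
  Add (Lit m) (Lit n) -> Just (Lit (m + n))
  Mul (Lit m) (Lit n) -> Just (Lit (m * n))
  Add a b -> stepPair Add a b
  Mul a b -> stepPair Mul a b
  _ -> Nothing
  where
    -- Step the left operand if possible, otherwise the right one.
    stepPair k a b = case step a of
      Just a' -> Just (k a' b)
      Nothing -> k a <$> step b

-- Print expressions the way the trace above does.
pretty :: Expr -> String
pretty e = case e of
  Var x -> x
  Lit n -> show n
  Lam x body -> "\\" ++ x ++ " -> " ++ pretty body
  App f a -> atom f ++ " " ++ atom a
  Add a b -> atom a ++ " + " ++ atom b
  Mul a b -> atom a ++ " * " ++ atom b
  where
    -- Parenthesise anything that is not a variable or a literal.
    atom x = case x of
      Var _ -> pretty x
      Lit _ -> pretty x
      _ -> "(" ++ pretty x ++ ")"

-- Every intermediate expression, until no rewrite applies.
steps :: Expr -> [Expr]
steps e = e : maybe [] steps (step e)

example :: Expr
example = App (Lam "x" (Add (Var "x") (Lit 5))) (Mul (Lit 2) (Lit 3))

main :: IO ()
main = mapM_ (putStrLn . pretty) (steps example)
```

Running this prints the same four lines as the trace above. Changing the body to `Add (Var "x") (Var "x")` makes the second line `(2 * 3) + (2 * 3)`: in this sketch the unevaluated argument is copied and the multiplication is later performed twice, which is what having no sharing and no thunks amounts to.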

## Docker run

Use the Docker distribution to run Duet easily on any platform:

    $ docker run -it -v $(pwd):/w -w /w chrisdone/duet run foo.hs

(This should work on Linux, OS X or Windows PowerShell.)

The image is about 11MB, so it's quick to download.

## Differences from Haskell

See also the next section for a complete example using all the
available syntax.

* Duet is non-strict, but is not lazy. There is no sharing and no
  thunks.
* No `module` or `import` module system whatsoever.
* No `let` syntax, and no parameters in definitions (e.g. `f x = ..`);
  you must use a lambda instead. Representing `let` in the stepper
  presents a design challenge that has not yet been met.
* Kinds `*` are written `Type`: e.g. `class Functor (f :: Type -> Type)`.
* Kind inference is not implemented, so if you want a kind other than
  `Type` (aka `*` in Haskell), you have to put a kind signature on the
  type variable.
* Indentation is stricter: a case's alternatives must be indented to a
  column greater than that of the `case` keyword.
* Duet does not have `seq`, but it does have bang patterns in
  cases. `case x of !x -> ..` is a perfectly legitimate way to force a
  value.
* Infix operators are stricter: an infix operator must have spaces
  around it. You **cannot** have more than one operator without
  parentheses, so operator precedence never comes into play in Duet
  (this is intentional). Because operators need surrounding spaces,
  you can also write `-5` without worrying about how it is parsed.
* Superclasses are not supported.
* Operator definitions are not supported.
* There are only two number types, `Integer` and `Rational`, written
  as `1` and `1.0` respectively.
* Any `_` or `_foo` is a "hole": the interpreter leaves it untouched
  and carries on performing rewrites around it. This is good for
  teaching.
* There is no standard `Prelude`. The only defined base types are:
  * String
  * Char
  * Integer
  * Rational
  * Bool
* You don't need a `Show` instance to inspect values; the interpreter
  shows them as they are, including lambdas.
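
The bang-pattern bullet above has a direct counterpart in GHC Haskell, which may help students moving between the two. This is a sketch in GHC (not Duet); `mySeq` and `lazyConst` are illustrative names. `mySeq` is written the same way as the `seq` defined in `examples/seq.hs`.

```haskell
{-# LANGUAGE BangPatterns #-}
import Control.Exception (ErrorCall, evaluate, try)

-- GHC analogue of Duet's `case x of !x -> ..` idiom:
-- the bang forces the scrutinee to weak head normal form before
-- the alternative's body is returned.
mySeq :: a -> b -> b
mySeq x y = case x of
  !_ -> y

-- Without the bang, a wildcard alternative never demands the scrutinee.
lazyConst :: a -> b -> b
lazyConst x y = case x of
  _ -> y

main :: IO ()
main = do
  -- The undefined argument is never evaluated here.
  putStrLn (lazyConst (undefined :: Int) "fine")
  -- Here it is forced, so the ErrorCall from undefined is raised and caught.
  r <- try (evaluate (mySeq (undefined :: Int) "forced")) :: IO (Either ErrorCall String)
  putStrLn (either (const "mySeq forced undefined") id r)
```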

View `examples/syntax-buffet.hs` for an example featuring all the
syntax supported in Duet.

## Print built-in types and classes

To print all types (primitive or otherwise), run:

    $ duet types

Example output:

```haskell
data Bool
  = True
  | False
data String
data Integer
data Rational
```

For classes and the instances of each class:

    $ duet classes

Example output:

```haskell
class Num a where
  plus :: forall a. (a -> a -> a)
  times :: forall a. (a -> a -> a)
instance Num Rational
instance Num Integer

class Neg a where
  negate :: forall a. (a -> a -> a)
  subtract :: forall a. (a -> a -> a)
  abs :: forall a. (a -> a)
instance Neg Rational
instance Neg Integer

class Fractional a where
  divide :: forall a. (a -> a -> a)
  recip :: forall a. (a -> a)
instance Fractional Rational

class Monoid a where
  append :: forall a. (a -> a -> a)
  empty :: forall a. a
instance Monoid String

class Slice a where
  drop :: forall a. (Integer -> a -> a)
  take :: forall a. (Integer -> a -> a)
instance Slice String
```

## String operations

Strings are provided as packed opaque literals. You can unpack them
via the `Slice` class:

```haskell
class Slice a where
  drop :: Integer -> a -> a
  take :: Integer -> a -> a
```

You can append strings using the `Monoid` class:

```haskell
class Monoid a where
  append :: a -> a -> a
  empty :: a
```

The `String` type is an instance of these classes.

``` haskell
main = append (take 2 (drop 7 "Hello, World!")) "!"
```

Because these string operations are primops, the expression evaluates strictly:

``` haskell
append (take 2 (drop 7 "Hello, World!")) "!"
append (take 2 "World!") "!"
append "Wo" "!"
"Wo!"
```
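
For comparison, the same interface can be modelled in GHC Haskell. This is a sketch only: Duet's strings are packed, opaque literals, whereas here `String` is GHC's `[Char]`, and the instances below are assumptions that merely reproduce the behaviour shown in the trace above.

```haskell
{-# LANGUAGE FlexibleInstances #-}
import Prelude hiding (Monoid, drop, take)
import qualified Prelude as P

-- A GHC model of Duet's built-in Slice and Monoid classes.
class Slice a where
  drop :: Integer -> a -> a
  take :: Integer -> a -> a

class Monoid a where
  append :: a -> a -> a
  empty :: a

-- String is [Char] here, so the list functions do the work.
instance Slice String where
  drop n = P.drop (fromInteger n)
  take n = P.take (fromInteger n)

instance Monoid String where
  append = (++)
  empty = ""

main :: IO ()
main = putStrLn (append (take 2 (drop 7 "Hello, World!")) "!")
```

This prints `Wo!`, matching the final line of the Duet trace.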

You can use this type and its operations to teach parsers (see `examples/parser.hs`).

## I/O

Basic terminal input/output is supported.

For example,

    $ duet run examples/terminal.hs --hide-steps
    Please enter your name:
    Chris
    Hello, Chris

And with steps:

    $ duet run examples/terminal.hs
    PutStrLn "Please enter your name: " (GetLine (\line -> PutStrLn (append "Hello, " line) (Pure 0)))
    Please enter your name:
    GetLine (\line -> PutStrLn (append "Hello, " line) (Pure 0))
    Chris
    (\line -> PutStrLn (append "Hello, " line) (Pure 0)) "Chris"
    PutStrLn (append "Hello, " "Chris") (Pure 0)
    Hello, Chris
    Pure 0

How does this work? Whenever the following code is seen in the
stepper:

```haskell
PutStrLn "Please enter your name: " <next>
```

The string is printed to stdout with `putStrLn`, and then the `<next>`
expression is stepped.

Whenever the following code is seen:

``` haskell
GetLine (\line -> <next>)
```

The stepper runs `getLine` and feeds the resulting string into the
stepper as:

```haskell
(\line -> <next>) "The line"
```

This enables one to write an example program like this:

``` haskell
data Terminal a
 = GetLine (String -> Terminal a)
 | PutStrLn String (Terminal a)
 | Pure a

main =
  PutStrLn
    "Please enter your name: "
    (GetLine (\line -> PutStrLn (append "Hello, " line) (Pure 0)))
```
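
The same `Terminal` type can be interpreted directly in GHC Haskell, which shows what the stepper's handling of `PutStrLn` and `GetLine` amounts to. This is a sketch, not Duet code: `runTerminal`, `runPure` and `program` are hypothetical names, and `++` stands in for Duet's `append`. `runPure` feeds canned input lines so the program can be tested without a real terminal.

```haskell
-- Interpreters for the Terminal type shown above (illustrative only).
data Terminal a
  = GetLine (String -> Terminal a)
  | PutStrLn String (Terminal a)
  | Pure a

-- Run against the real terminal.
runTerminal :: Terminal a -> IO a
runTerminal t = case t of
  PutStrLn s next -> putStrLn s >> runTerminal next
  GetLine k -> getLine >>= runTerminal . k
  Pure a -> pure a

-- Run purely: consume canned input lines and collect printed lines.
-- The result is Nothing if the program asks for more input than given.
runPure :: [String] -> Terminal a -> ([String], Maybe a)
runPure inputs t = case t of
  PutStrLn s next ->
    let (out, result) = runPure inputs next
     in (s : out, result)
  GetLine k -> case inputs of
    line : rest -> runPure rest (k line)
    [] -> ([], Nothing)
  Pure a -> ([], Just a)

program :: Terminal Integer
program =
  PutStrLn
    "Please enter your name: "
    (GetLine (\line -> PutStrLn ("Hello, " ++ line) (Pure 0)))

main :: IO ()
main = mapM_ putStrLn (fst (runPure ["Chris"] program))
```

The Duet stepper does the `GetLine` case differently in one visible way: it rewrites `GetLine (\line -> ...)` to `(\line -> ...) "Chris"` and keeps stepping, so the substitution appears in the trace, whereas `runTerminal` simply applies the function.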


================================================
FILE: app/Main.hs
================================================
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}

-- |

import           Control.Monad.Catch
import           Control.Monad.Logger
import           Control.Monad.Supply
import           Control.Monad.Writer
import qualified Data.Map.Strict as M
import           Data.Semigroup ((<>))
import           Duet.Context
import           Duet.Errors
import           Duet.Infer
import           Duet.Parser
import           Duet.Printer
import           Duet.Renamer
import           Duet.Setup
import           Duet.Simple
import           Duet.Stepper
import           Duet.Types
import           Options.Applicative.Simple
import           System.IO

data Run = Run
  { runInputFile :: FilePath
  , runMainIs :: String
  , runConcise :: Bool
  , runNumbered :: Bool
  , runSteps :: Maybe Integer
  , runHideSteps :: Bool
  } deriving (Show)

main :: IO ()
main = do
  hSetBuffering stdout LineBuffering
  hSetBuffering stdin LineBuffering
  ((), cmd) <-
    simpleOptions
      "1.0"
      "Duet interpreter"
      "This is the interpreter for the Duet mini-Haskell educational language"
      (pure ())
      (do addCommand "types" "Print types in scope" runTypesPrint (pure ())
          addCommand "classes" "Print classes in scope" runClassesPrint (pure ())
          addCommand
            "run"
            "Run the given program source"
            runProgram
            (Run <$>
             strArgument
               (metavar "FILEPATH" <> help "The .hs file to interpret") <*>
             strOption
               (long "main" <> metavar "NAME" <> help "The main value to run" <>
                value "main") <*>
             flag False True (long "concise" <> help "Concise view") <*>
             flag False True (long "numbered" <> help "Number outputs") <*>
             optional
               (option
                  auto
                  (long "steps" <> short 'n' <> metavar "steps" <>
                   help "Maximum number of steps to run (default: unlimited)")) <*>
             flag
               False
               True
               (long "hide-steps" <> help "Do not print the steps to stdout")))
  cmd

runTypesPrint :: () -> IO ()
runTypesPrint _ = do
  builtins <- evalSupplyT (setupEnv mempty []) [1 ..]
  putStrLn
    (printDataType
       defaultPrint
       (builtinsSpecialTypes builtins)
       (specialTypesBool (builtinsSpecialTypes builtins)))
  when
    False
    (putStrLn
       (printTypeConstructorOpaque
          defaultPrint
          (specialTypesChar (builtinsSpecialTypes builtins))))
  putStrLn
    (printTypeConstructorOpaque
       defaultPrint
       (specialTypesString (builtinsSpecialTypes builtins)))
  putStrLn
    (printTypeConstructorOpaque
       defaultPrint
       (specialTypesInteger (builtinsSpecialTypes builtins)))
  putStrLn
    (printTypeConstructorOpaque
       defaultPrint
       (specialTypesRational (builtinsSpecialTypes builtins)))
  where
    printTypeConstructorOpaque p = ("data " ++) . printTypeConstructor p

runClassesPrint :: () -> IO ()
runClassesPrint _ = do
  builtins <- evalSupplyT (setupEnv mempty []) [1 ..]
  mapM_
    (putStrLn . (++ "\n") . printClass defaultPrint (builtinsSpecialTypes builtins))
    (M.elems (builtinsTypeClasses builtins))

runProgram :: Run -> IO ()
runProgram run@Run {..} = do
  catch
    (catch
       (runNoLoggingT
          (evalSupplyT
             (do decls <- liftIO (parseFile runInputFile)
                 (binds, ctx) <- createContext decls
                 things <-
                   execWriterT
                     (runStepperIO
                        run
                        runSteps
                        ctx
                        (fmap (fmap typeSignatureA) binds)
                        runMainIs)
                 pure things)
             [1 ..]))
       (putStrLn . displayContextException))
    (putStrLn . displayParseException)

-- | Run the substitution model on the code.
runStepperIO ::
     forall m. (MonadSupply Int m, MonadThrow m, MonadIO m)
  => Run
  -> Maybe Integer
  -> Context Type Name Location
  -> [BindGroup Type Name Location]
  -> String
  -> m ()
runStepperIO Run {..} maxSteps ctx bindGroups' i = do
  e0 <- lookupNameByString i bindGroups'
  loop 1 "" e0
  where
    loop :: Integer -> String -> Expression Type Name Location -> m ()
    loop count lastString e = do
      e' <- expandSeq1 ctx bindGroups' e
      let string = printExpression (defaultPrint) e
      when
        (string /= lastString && not runHideSteps)
        (if cleanExpression e || not runConcise
           then liftIO
                  (putStrLn
                     ((if runNumbered
                         then "[" ++ show count ++ "]\n"
                         else "") ++
                      printExpression defaultPrint e))
           else pure ())
      e'' <- pickUpIO e'
      if (fmap (const ()) e'' /= fmap (const ()) e) &&
         case maxSteps of
           Just top -> count < top
           Nothing -> True
        then do
          newE <-
            renameExpression
              (contextSpecials ctx)
              (contextScope ctx)
              (contextDataTypes ctx)
              e''
          loop (count + 1) string newE
        else pure ()

pickUpIO :: MonadIO m => Expression t Name l -> m (Expression t Name l)
pickUpIO =
  \case
    ApplicationExpression _ (ApplicationExpression _ (ConstructorExpression _ (ConstructorName _ "PutStrLn")) (LiteralExpression _ (StringLiteral toBePrinted))) next -> do
      liftIO (putStrLn toBePrinted)
      pure next
    ApplicationExpression l (ConstructorExpression _ (ConstructorName _ "GetLine")) func -> do
      inputString <- liftIO getLine
      pure (ApplicationExpression l func (LiteralExpression l (StringLiteral inputString)))
    e -> pure e

-- | Filter out expressions with intermediate case, if and immediately-applied lambdas.
cleanExpression :: Expression Type i l -> Bool
cleanExpression =
  \case
    CaseExpression {} -> False
    IfExpression {} -> False
    e0
      | (LambdaExpression {}, args) <- fargs e0 -> null args
    ApplicationExpression _ f x -> cleanExpression f && cleanExpression x
    _ -> True


================================================
FILE: duet.cabal
================================================
name:
  duet
version:
  0.0.2
cabal-version:
  >=1.10
build-type:
  Simple
maintainer:
  chrisdone@gmail.com
synopsis:
  A tiny language, a subset of Haskell (with type classes) aimed at aiding teachers to teach Haskell
description:
  A tiny language, a subset of Haskell (with type classes) aimed at aiding teachers to teach Haskell
license: BSD3
extra-source-files: README.md, LICENSE.md

library
  hs-source-dirs:
    src
  build-depends:
    base >= 4.5 && < 5,
    containers,
    mtl,
    exceptions,
    parsec,
    text,
    edit-distance,
    deepseq,
    aeson,
    syb,
    monad-logger
  ghc-options:
    -Wall
  default-language:
    Haskell2010
  exposed-modules:
    Duet.Infer
    Duet.Types
    Duet.Parser
    Duet.Printer
    Duet.Tokenizer
    Duet.Renamer
    Duet.Resolver
    Duet.Stepper
    Duet.Errors
    Duet.Supply
    Duet.Context
    Duet.Setup
    Duet.Simple
    Control.Monad.Supply

test-suite duet-test
  type: exitcode-stdio-1.0
  main-is: Spec.hs
  hs-source-dirs: test
  ghc-options: -Wall -O0
  default-language: Haskell2010
  build-depends:
      base >= 4.5 && < 5, duet,
    containers,
    mtl,
    exceptions,
    parsec,
    text,
    edit-distance,
    deepseq,
    aeson,
    syb,
    hspec,
    monad-logger

executable duet
  main-is: Main.hs
  hs-source-dirs: app
  ghc-options: -Wall
  default-language: Haskell2010
  build-depends:
      base >= 4.5 && < 5, duet,
    containers,
    mtl,
    exceptions,
    text,
    deepseq,
    aeson,
    syb,
    monad-logger,
    optparse-simple


================================================
FILE: examples/ack.hs
================================================
data Tuple a b = Tuple a b

ack = \m n ->
  case Tuple m n of
    Tuple 0 n -> n + 1
    Tuple m 0 -> ack (m - 1) 1
    Tuple m n -> ack (m - 1) (ack m (n - 1))

main = ack 4 0


================================================
FILE: examples/arith.hs
================================================
main = 22.0 + 33.0


================================================
FILE: examples/bound.hs
================================================
class Bounded a where
  minBound :: a
  maxBound :: a
instance Bounded Bool where
  minBound = False
  maxBound = True
data Tuple a a = Tuple a a
main = Tuple True minBound


================================================
FILE: examples/builtins.hs
================================================
data X = X Integer Char Rational String
class Show a where show :: a -> String
instance Show Integer where show = \_ -> "a"
foo :: X -> Integer
foo = \x -> 123


================================================
FILE: examples/classes.hs
================================================
class Reader a where
  reader :: List Ch -> a
class Shower a where
  shower :: a -> List Ch
instance Shower Nat where
  shower = \n ->
    case n of
      Zero -> Cons Z Nil
      Succ n -> Cons S (shower n)
data Nat = Succ Nat | Zero
instance Reader Nat where
  reader = \cs ->
    case cs of
      Cons Z Nil -> Zero
      Cons S xs  -> Succ (reader xs)
      _ -> Zero
data List a = Nil | Cons a (List a)
data Ch = A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | X | Y | Z
class Equal a where
  equal :: a -> a -> Bool
instance Equal Nat where
  equal =
    \a b ->
      case a of
        Zero ->
          case b of
            Zero -> True
            _ -> False
        Succ n ->
          case b of
            Succ m -> equal n m
            _ -> False
        _ -> False
not = \b -> case b of
              True -> False
              False -> True

notEqual :: Equal a => a -> a -> Bool
notEqual = \x y -> not (equal x y)

main = equal (reader (shower (Succ Zero))) (Succ Zero)


================================================
FILE: examples/fac.hs
================================================
factorial = \n -> case n of
                    0 -> 1
                    1 -> 1
                    _ -> n * factorial (n - 1)


go =
  \n acc0 ->
    case acc0 of
      acc ->
        case n of
          0 -> acc
          1 -> acc
          _ -> go (n - 1) (n * acc)

go_efficient =
  \n acc0 ->
    case acc0 of
      !acc ->
        case n of
          0 -> acc
          1 -> acc
          nf -> go_efficient (nf - 1) (nf * acc)

it = go 5 1

it_efficient = go_efficient 5 1


================================================
FILE: examples/factorial.hs
================================================
data N = S N | Z | M N N
sub = \n -> case n of
              S c -> c
fac = \n -> case n of
              Z -> S Z
              _ -> M n (fac (sub n))

facAcc = \a n ->
  case n of
    Z -> a
    _ -> facAcc (M n a) (sub n)

facA = facAcc (S Z)
id = \x -> x
main = fac (S (S Z))


================================================
FILE: examples/folds-strictness.hs
================================================
data List a = Nil | Cons a (List a)
foldr = \f z l ->
  case l of
    Nil -> z
    Cons x xs -> f x (foldr f z xs)
foldl = \f z l ->
  case l of
    Nil -> z
    Cons x xs -> foldl f (f z x) xs
foldl_ = \f z l ->
  case l of
    Nil -> z
    Cons x xs ->
      case f z x of
        !z_ -> foldl_ f z_ xs
list = (Cons 1 (Cons 2 (Cons 3 (Cons 4 Nil))))
main_foldr = foldr _f _nil list
main_foldl = foldl _f _nil list
main_foldl_ = foldl_ (\x y -> x + y) 0 list


================================================
FILE: examples/folds.hs
================================================
data List a = Nil | Cons a (List a)
foldr = \f z l ->
  case l of
    Nil -> z
    Cons x xs -> f x (foldr f z xs)
foldl = \f z l ->
  case l of
    Nil -> z
    Cons x xs -> foldl f (f z x) xs
list = (Cons True (Cons False Nil))

main_foldr = foldr _f _nil list
main_foldl = foldl _f _nil list


================================================
FILE: examples/functor-class.hs
================================================
data Maybe a = Nothing | Just a
class Functor (f :: Type -> Type) where
  map :: (a -> b) -> f a -> f b
instance Functor Maybe where
  map = \f m ->
    case m of
      Nothing -> Nothing
      Just a -> Just (f a)
not = \b -> case b of
              True -> False
              False -> True
main = map (\x -> x) (Just 123)


================================================
FILE: examples/gabriel-eq-reason.hs
================================================
data IO a = Print Nat (IO a) | Return a

data Nat = Z | S Nat

data List a = Nil | Cons a (List a)

data Unit = Unit

bind =
  \m f ->
    case m of
      Return a -> f a
      Print bool m1 -> Print bool (bind m1 f)

next = \m n ->
  bind m (\_ -> n)

print = \x -> Print x (Return Unit)

return = Return

repeat = \x -> Cons x (repeat x)

foldr = \cons nil l ->
  case l of
    Nil -> nil
    Cons x xs -> cons x (foldr cons nil xs)

sequence_ = \ms -> foldr next (return Unit) ms

take =
  \n l ->
    case n of
      Z -> Nil
      S m ->
        case l of
          Nil -> Nil
          Cons x xs -> Cons x (take m xs)

replicate = \n x -> take n (repeat x)

replicateM_ = \n m -> sequence_ (replicate n m)

main = replicateM_ (S (S (S (S (S (S Z)))))) (print (S Z))


================================================
FILE: examples/good.hs
================================================
class Good a where
  good :: a -> Bool
data Maybe a = Just a | Nothing
instance Good Bool where
  good = \x -> x
instance Good a => Good (Maybe a) where
  good = \x ->
    case x of
      Nothing -> False
      Just a -> good a
main = good (Just True)


================================================
FILE: examples/integers.hs
================================================
main = 3 + ((2 + -3) - 3)


================================================
FILE: examples/lists.hs
================================================
data List a = Nil | Cons a (List a)
map = \f xs ->
  case xs of
    Nil -> Nil
    Cons x xs -> Cons (f x) (map f xs)
list = (Cons 1 (Cons 2 Nil))
multiply = \x y -> x * y
doubleAll = \xs -> map (multiply 2) xs
main = doubleAll list


================================================
FILE: examples/monad.hs
================================================
class Monad (m :: Type -> Type) where
  bind :: m a -> (a -> m b) -> m b
class Applicative (f :: Type -> Type) where
  pure :: a -> f a
  ap :: f (a -> b) -> f a -> f b
class Functor (f :: Type -> Type) where
  map :: (a -> b) -> f a -> f b
data Maybe a = Nothing | Just a
instance Functor Maybe where
  map =
    \f m ->
      case m of
        Nothing -> Nothing
        Just a -> Just (f a)
instance Monad Maybe where
  bind =
    \m f ->
      case m of
        Nothing -> Nothing
        Just v -> f v
instance Applicative Maybe where
  pure = \v -> Just v
  ap = \a b -> Nothing


================================================
FILE: examples/monoid.hs
================================================
class Monoid a where
  mempty  :: a
  mappend :: a -> a -> a
data List a = Nil | Cons a (List a)
instance Monoid (List a) where
  mempty = Nil
  mappend = \x y ->
    case x of
      Cons a xs -> Cons a (mappend xs y)
      Nil -> y
main = mappend (Cons 'a' (Cons 'b' Nil)) (Cons 'c' (Cons 'd' Nil))


================================================
FILE: examples/ord.hs
================================================
class Ord a  where
  compare :: a -> a -> Ordering
data Ordering
  = EQ
  | LT
  | GT
instance Ord Ordering where
  compare =
    \x y ->
      case x of
        LT ->
          case y of
            LT -> EQ
            EQ -> LT
            GT -> LT
        EQ ->
          case y of
            LT -> GT
            EQ -> EQ
            GT -> LT
        GT ->
          case y of
            LT -> GT
            EQ -> GT
            GT -> EQ
main = compare EQ LT


================================================
FILE: examples/parser.hs
================================================
data Tuple a b = Tuple a b
data Result a = OK a String | Error String
data Parser a = Parser (String -> Result a)
parseBool =
  Parser
    (\string ->
       case take 4 string of
         "True" ->
           case drop 4 string of
             !rest -> OK True rest
         _ ->
           case take 5 string of
             "False" ->
               case drop 5 string of
                 !rest -> OK False rest
             _ -> Error (append "Expected a bool, but got: " string))
runParser =
  \p s ->
    case p of
      Parser f -> f s
bind =
  \m f ->
    Parser
      (\s ->
         case runParser m s of
           OK a rest -> runParser (f a) rest
           Error err -> Error err)
pure = \a -> Parser (OK a)
main = runParser (bind parseBool (\x -> bind parseBool (\y -> pure (Tuple x y)))) "TrueFalse"


================================================
FILE: examples/pattern-matching.hs
================================================
data Uk = Manchester | Bristol

data Italy = Trento | Padova

data Europe = Uk Uk | Italy Italy

bristol = Bristol

main = case Uk bristol of
         Uk Manchester -> "uk-manc"
         Uk Bristol -> "uk-bristol"
         Italy Trento -> "italy-trento"
         Italy Padova -> "italy-padova"


================================================
FILE: examples/placeholders.hs
================================================
data List a = Nil | Cons a (List a)
foldr = \f z l ->
  case l of
    Nil -> z
    Cons x xs -> f x (foldr f z xs)
foldl = \f z l ->
  case l of
    Nil -> z
    Cons x xs -> foldl f (f z x) xs
list = (Cons True (Cons False Nil))
main = foldr _f _nil list


================================================
FILE: examples/prelude.hs
================================================
data Bool = True | False

data Ordering = EQ | LT | GT

class Eq a where
  equal :: a -> a -> Bool
  notEqual :: a -> a -> Bool

class Ord a where
  compare :: a -> a -> Ordering

class Monad (m :: Type -> Type) where
  bind :: m a -> (a -> m b) -> m b

class Applicative (f :: Type -> Type) where
  pure :: a -> f a
  ap :: f (a -> b) -> f a -> f b

class Functor (f :: Type -> Type) where
  map :: (a -> b) -> f a -> f b

class Num a where
  plus :: a -> a -> a
  times :: a -> a -> a

class Neg a where
  negate :: a -> a
  abs :: a -> a
  subtract :: a -> a -> a

class MinBound b where
  minBound :: b

class MaxBound b where
  maxBound :: b

class Integral a where
  div :: a -> a -> a
  mod :: a -> a -> a

class Fractional a where
  divide :: a -> a -> a
  recip :: a -> a


================================================
FILE: examples/seq.hs
================================================
seq :: a -> b -> b
seq =
  \x y ->
    case x of
      !_ -> y
loop = loop
main = seq loop 1


================================================
FILE: examples/sicp.hs
================================================
square = \x -> x * x
it = square 6 + square 10


================================================
FILE: examples/simple-class.hs
================================================
class X a where
 f :: a -> D
data D = D | C
instance X D where
 f = \x -> case x of
             D -> D
             C -> f D

main = f C


================================================
FILE: examples/state.hs
================================================
data Unit = Unit
class Monad (m :: Type -> Type) where
  bind :: m a -> (a -> m b) -> m b
class Applicative (f :: Type -> Type) where
  pure :: a -> f a
class Functor (f :: Type -> Type) where
  map :: (a -> b) -> f a -> f b
data Result s a = Result s a
data State s a = State (s -> Result s a)
instance Functor (State s) where
  map =
    \f state ->
      case state of
        State s2r ->
          State
            (\s ->
               case s2r s of
                 Result s1 a -> Result s1 (f a))
instance Monad (State s) where
  bind =
    \m f ->
      case m of
        State s2r ->
          State
            (\s ->
               case s2r s of
                 Result s a ->
                   case f a of
                     State s2r1 -> s2r1 s)
instance Applicative (State s) where
  pure = \a -> State (\s -> Result s a)
runState =
  \m a ->
    case m of
      State f -> f a
get = State (\s -> Result s s)
put = \s -> State (\k -> Result s Unit)
next = \m n -> bind m (\_ -> n)
main = runState (next (put False) (pure Unit)) True


================================================
FILE: examples/strict-folds.hs
================================================
data List a = Nil | Cons a (List a)
foldr = \f z l ->
  case l of
    Nil -> z
    Cons x xs -> f x (foldr f z xs)
foldl = \f z l ->
  case l of
    Nil -> z
    Cons x xs -> foldl f (f z x) xs
list = (Cons 1 (Cons 2 (Cons 3 (Cons 4 Nil))))
main_foldr = foldr (+) _nil list
main_foldl = foldl (+) _nil list


================================================
FILE: examples/string-pats.hs
================================================
main =
  case "foo" of
    "bar" -> 0
    "foo" -> 1


================================================
FILE: examples/string-substring.hs
================================================
main = append (take 2 (drop 7 "Hello, World!")) "!"


================================================
FILE: examples/syntax-buffet.hs
================================================
class Reader a where
  reader :: List Ch -> a
class Shower a where
  shower :: a -> List Ch
instance Shower Nat where
  shower = \n ->
    case n of
      Zero -> Cons Z Nil
      Succ n -> Cons S (shower n)
data Nat = Succ Nat | Zero
instance Reader Nat where
  reader = \cs ->
    case cs of
      Cons Z Nil -> Zero
      Cons S xs  -> Succ (reader xs)
      _ -> Zero
data List a = Nil | Cons a (List a)
data Ch = A | B | C | D | E | F | G | H | I | J | K | L | M | N | O | P | Q | R | S | T | U | V | W | X | Y | Z
class Equal a where
  equal :: a -> a -> Bool
instance Equal Nat where
  equal =
    \a b ->
      case a of
        Zero ->
          case b of
            Zero -> True
            _ -> False
        Succ n ->
          case b of
            Succ m -> equal n m
            _ -> False
        _ -> False
not = \b -> case b of
              True -> False
              False -> True
notEqual :: Equal a => a -> a -> Bool
notEqual = \x y -> not (equal x y)
main = if not False
          then equal (reader (shower (Succ Zero))) (Succ Zero)
          else False
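The core of this example is a round trip: `shower` encodes a `Nat` in unary as `S ... S Z`, and `reader` decodes it again. Below is a trimmed standard-Haskell sketch of that round trip. It keeps only the two letters the encoding uses and relies on a derived `Eq` instead of the example's `Equal` class; `roundTrips` is an illustrative helper.

```haskell
data Nat = Succ Nat | Zero
  deriving (Show, Eq)

-- Only the two letters used by the unary encoding.
data Ch = S | Z
  deriving (Show, Eq)

class Shower a where
  shower :: a -> [Ch]

class Reader a where
  reader :: [Ch] -> a

instance Shower Nat where
  shower Zero = [Z]
  shower (Succ n) = S : shower n

instance Reader Nat where
  reader [Z] = Zero
  reader (S : xs) = Succ (reader xs)
  reader _ = Zero -- malformed input falls back to Zero, as in the Duet code

-- Decoding an encoded value gives the original back.
roundTrips :: Nat -> Bool
roundTrips n = reader (shower n) == n
```

For example, `shower (Succ (Succ Zero))` is `[S, S, Z]`, and reading it back gives `Succ (Succ Zero)`.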


================================================
FILE: examples/terminal.hs
================================================
data Terminal a
 = GetLine (String -> Terminal a)
 | PutStrLn String (Terminal a)
 | Pure a

main =
  PutStrLn
    "Please enter your name: "
    (GetLine (\line -> PutStrLn (append "Hello, " line) (Pure 0)))
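`Terminal` describes an interactive program as data: each constructor is one instruction, and the continuation holds the rest of the program. An interpreter gives it meaning. The sketch below shows two hypothetical interpreters in standard Haskell. `runPure` feeds scripted input lines and collects the output, returning `Nothing` if input runs out. `runIO` performs real console I/O.

```haskell
data Terminal a
  = GetLine (String -> Terminal a)
  | PutStrLn String (Terminal a)
  | Pure a

-- Hypothetical pure interpreter: scripted input in, printed lines out.
runPure :: [String] -> Terminal a -> ([String], Maybe a)
runPure inputs t =
  case t of
    Pure a -> ([], Just a)
    PutStrLn s k ->
      let (out, r) = runPure inputs k
       in (s : out, r)
    GetLine k ->
      case inputs of
        l : ls -> runPure ls (k l)
        [] -> ([], Nothing) -- ran out of scripted input

-- Hypothetical IO interpreter using the real console.
runIO :: Terminal a -> IO a
runIO (Pure a) = pure a
runIO (PutStrLn s k) = putStrLn s >> runIO k
runIO (GetLine k) = getLine >>= runIO . k

program :: Terminal Int
program =
  PutStrLn
    "Please enter your name: "
    (GetLine (\line -> PutStrLn ("Hello, " ++ line) (Pure 0)))
```

With the scripted input `["Chris"]`, `runPure` prints the prompt followed by `"Hello, Chris"` and finishes with `Just 0`.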


================================================
FILE: src/Control/Monad/Supply.hs
================================================
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- | Support for computations which consume values from a (possibly infinite)
-- supply. See <http://www.haskell.org/haskellwiki/New_monads/MonadSupply> for
-- details.
--
-- Patched to provide MonadCatch/MonadThrow instead of MonadError.
--
module Control.Monad.Supply
( MonadSupply (..)
, SupplyT
, Supply
, evalSupplyT
, evalSupply
, runSupplyT
, runSupply
) where

import Control.Monad.Catch
import Control.Monad.Identity
#ifndef __GHCJS__
import Control.Monad.Logger
#endif
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Writer

class Monad m => MonadSupply s m | m -> s where
  supply :: m s
  peek :: m s
  exhausted :: m Bool

-- | Supply monad transformer.
newtype SupplyT s m a = SupplyT (StateT [s] m a)
#ifdef __GHCJS__
  deriving (Functor, Applicative, Monad, MonadTrans, MonadIO, MonadFix, MonadCatch, MonadThrow)
#else
  deriving (Functor, Applicative, Monad, MonadTrans, MonadIO, MonadFix, MonadCatch, MonadThrow, MonadLogger)
#endif
-- | Supply monad.
newtype Supply s a = Supply (SupplyT s Identity a)
  deriving (Functor, Applicative, Monad, MonadSupply s, MonadFix)

instance Monad m => MonadSupply s (SupplyT s m) where
  supply =
    SupplyT $ do
      result <- get
      case result of
        (x:xs) -> do
          put xs
          return x
        _ -> error "Exhausted supply in Control.Monad.Supply.hs"
  peek =
    SupplyT $ do
      result <- get
      case result of
        (x:_) -> return x
        _ -> error "Exhausted supply in Control.Monad.Supply.hs"
  exhausted = SupplyT $ gets null

instance MonadSupply s m => MonadSupply s (StateT st m) where
  supply = lift supply
  peek = lift peek
  exhausted = lift exhausted

instance MonadSupply s m => MonadSupply s (ReaderT r m) where
  supply = lift supply
  peek = lift peek
  exhausted = lift exhausted

instance (Monoid w, MonadSupply s m) => MonadSupply s (WriterT w m) where
  supply = lift supply
  peek = lift peek
  exhausted = lift exhausted

evalSupplyT :: Monad m => SupplyT s m a -> [s] -> m a
evalSupplyT (SupplyT s) = evalStateT s

evalSupply :: Supply s a -> [s] -> a
evalSupply (Supply s) = runIdentity . evalSupplyT s

runSupplyT :: Monad m => SupplyT s m a -> [s] -> m (a,[s])
runSupplyT (SupplyT s) = runStateT s

runSupply :: Supply s a -> [s] -> (a,[s])
runSupply (Supply s) = runIdentity . runSupplyT s
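The module above wraps `StateT [s] m`, and `supply` pops the head of the remaining list. This is how a compiler draws fresh, unique identifiers from an infinite list such as `[0 ..]`. The dependency-free sketch below hand-rolls the same state-passing idea so that it runs without `mtl`. It is illustrative only, not the module's implementation, and `freshNames` is a hypothetical helper.

```haskell
-- A hand-rolled supply monad: thread the remaining supply through each step.
newtype Supply s a = Supply {runSupply :: [s] -> (a, [s])}

instance Functor (Supply s) where
  fmap f (Supply g) = Supply $ \ss -> let (a, ss') = g ss in (f a, ss')

instance Applicative (Supply s) where
  pure a = Supply $ \ss -> (a, ss)
  Supply gf <*> Supply ga = Supply $ \ss ->
    let (f, ss1) = gf ss
        (a, ss2) = ga ss1
     in (f a, ss2)

instance Monad (Supply s) where
  Supply g >>= k = Supply $ \ss -> let (a, ss') = g ss in runSupply (k a) ss'

-- Take the next value, failing like the module does when the supply is empty.
supply :: Supply s s
supply = Supply $ \ss ->
  case ss of
    x : xs -> (x, xs)
    [] -> error "Exhausted supply"

evalSupply :: Supply s a -> [s] -> a
evalSupply m = fst . runSupply m

-- Hypothetical helper: n fresh variable names drawn from the supply.
freshNames :: Int -> Supply Int [String]
freshNames n = mapM (const (fmap (\i -> "v" ++ show i) supply)) [1 .. n]
```

For example, `evalSupply (freshNames 3) [0 ..]` yields `["v0", "v1", "v2"]`. Laziness means only the first three elements of the infinite supply are ever consumed.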


================================================
FILE: src/Duet/Context.hs
================================================
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleContexts #-}

-- | Functions for setting up the context.

module Duet.Context where

import           Control.Monad
import           Control.Monad.Catch
import           Control.Monad.Supply
import qualified Data.Map.Strict as M
import           Data.Maybe
import           Duet.Infer
import           Duet.Renamer
import           Duet.Supply
import           Duet.Types

-- | Make an instance.
makeInst
  :: MonadSupply Int m
  => Specials Name
  -> Predicate Type Name
  -> [(String, (l, Alternative Type Name l))]
  -> m (Instance Type Name l)
makeInst specials pred' methods = do
  name <- supplyDictName (predicateToDict specials pred')
  methods' <-
    mapM
      (\(key, alt) -> do
         key' <- supplyMethodName (Identifier key)
         pure (key', alt))
      methods
  pure (Instance (Forall [] (Qualified [] pred')) (Dictionary name (M.fromList methods')))

-- | Make a class.
makeClass
  :: MonadSupply Int m
  => Identifier
  -> [TypeVariable Name]
  -> [(Name, Scheme t Name t)]
  -> m (Class t Name l)
makeClass name vars methods = do
  name' <- supplyClassName name
  pure
    (Class
     { className = name'
     , classTypeVariables = vars
     , classInstances = []
     , classMethods = M.fromList methods
     , classSuperclasses = mempty
     })

-- | Generate signatures from a data type.
dataTypeSignatures
  :: Monad m
  => SpecialTypes Name -> DataType Type Name -> m [TypeSignature Type Name Name]
dataTypeSignatures specialTypes dt@(DataType _ vs cs) = mapM construct cs
  where
    construct (DataTypeConstructor cname fs) =
      pure
        (TypeSignature
           cname
           (Forall
              vs
              (Qualified
                 []
                 (foldr
                    makeArrow
                    (foldl
                       ApplicationType
                       (dataTypeConstructor dt)
                       (map VariableType vs))
                    fs))))
      where
        makeArrow :: Type Name -> Type Name -> Type Name
        a `makeArrow` b =
          ApplicationType
            (ApplicationType
               (ConstructorType (specialTypesFunction specialTypes))
               a)
            b
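`dataTypeSignatures` gives each constructor the type `forall vs. f1 -> ... -> fn -> T vs`. `foldl ApplicationType` builds the result type by applying the type constructor to its variables, and `foldr makeArrow` then prepends one arrow per field. The sketch below shows the same construction over a hypothetical miniature type AST. `Type`, `arrow`, `constructorType` and `pretty` are illustrative, not Duet's definitions.

```haskell
data Type = Con String | Var String | App Type Type
  deriving (Show, Eq)

-- a -> b, encoded as the arrow constructor applied to two arguments.
arrow :: Type -> Type -> Type
arrow a b = App (App (Con "->") a) b

-- Result type T v1 ... vn via foldl, then one arrow per field via foldr.
constructorType :: String -> [String] -> [Type] -> Type
constructorType tyName vars fields =
  foldr arrow (foldl App (Con tyName) (map Var vars)) fields

-- Fully parenthesised rendering, so the structure is visible.
pretty :: Type -> String
pretty (Con c) = c
pretty (Var v) = v
pretty (App (App (Con "->") a) b) = "(" ++ pretty a ++ " -> " ++ pretty b ++ ")"
pretty (App f x) = "(" ++ pretty f ++ " " ++ pretty x ++ ")"
```

For `data List a = Nil | Cons a (List a)`, the `Cons` fields `[Var "a", App (Con "List") (Var "a")]` produce `(a -> ((List a) -> (List a)))`. `Nil` has no fields, so its type is just `List a`.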

-- | Make signatures from a class.
classSignatures
  :: MonadThrow m
  => Class Type Name l -> m [TypeSignature Type Name Name]
classSignatures cls =
  mapM
    (\(name, scheme) ->
       TypeSignature <$> pure name <*> classMethodScheme cls scheme)
    (M.toList (classMethods cls))

builtinsSpecials :: Builtins t i l -> Specials i
builtinsSpecials builtins =
  Specials (builtinsSpecialSigs builtins) (builtinsSpecialTypes builtins)

contextSpecials :: Context t i l -> Specials i
contextSpecials context =
  Specials (contextSpecialSigs context) (contextSpecialTypes context)

generateAllSignatures :: (MonadThrow m, Traversable t, Traversable t1) => Builtins Type Name l1 -> t1 (DataType Type Name) -> t (Class Type Name l) -> m [TypeSignature Type Name Name]
generateAllSignatures builtins dataTypes typeClasses =
  do consSigs <-
       fmap
         concat
         (mapM (dataTypeSignatures (builtinsSpecialTypes builtins)) dataTypes)
     methodSigs <- fmap concat (mapM classSignatures typeClasses)
     pure (builtinsSignatures builtins <> consSigs <> methodSigs)

makeScope :: Applicative f => M.Map Identifier (Class t2 Name l) -> [TypeSignature t1 t Name] -> f (M.Map Identifier Name)
makeScope typeClasses signatures =
  pure
    (M.fromList
       (mapMaybe
          (\(TypeSignature name _) ->
             case name of
               ValueName _ ident -> Just (Identifier ident, name)
               ConstructorName _ ident -> Just (Identifier ident, name)
               MethodName _ ident -> Just (Identifier ident, name)
               _ -> Nothing)
          signatures) <>
     M.map className typeClasses)

renameEverything ::
     (MonadThrow m, MonadSupply Int m)
  => [Decl UnkindedType Identifier Location]
  -> Specials Name
  -> Builtins Type Name Location
  -> m ( M.Map Identifier (Class Type Name Location)
       , [TypeSignature Type Name Name]
       , [Binding Type Name Location]
       , M.Map Identifier Name
       , [DataType Type Name])
renameEverything decls specials builtins = do
  dataTypes <- renameDataTypes specials (declsDataTypes decls)
  (typeClasses, signatures, subs) <-
    do typeClasses <-
         fmap
           M.fromList
           (mapM
              (\c -> do
                 renamed <- renameClass specials mempty dataTypes c
                 pure (className c, renamed))
              classes)
       signatures <- generateAllSignatures builtins dataTypes typeClasses
       scope <- makeScope typeClasses signatures
       allInstances <-
         mapM
           (renameInstance specials scope dataTypes (M.elems typeClasses))
           instances
       pure
         ( M.map
             (\typeClass ->
                typeClass
                { classInstances =
                    filter
                      ((== className typeClass) . instanceClassName)
                      allInstances
                })
             typeClasses
         , signatures
         , scope)
  (renamedBindings, subs') <- renameBindings specials subs dataTypes bindings
  pure (typeClasses, signatures, renamedBindings, subs', dataTypes)
  where declsDataTypes =
          mapMaybe
            (\case
               DataDecl _ d -> Just d
               _ -> Nothing)
        bindings =
          mapMaybe
            (\case
               BindDecl _ d -> Just d
               _ -> Nothing)
            decls
        classes =
          mapMaybe
            (\case
               ClassDecl _ d -> Just d
               _ -> Nothing)
            decls
        instances =
          mapMaybe
            (\case
               InstanceDecl _ d -> Just d
               _ -> Nothing)
            decls

addClasses :: (MonadThrow m, Foldable t) => Builtins Type Name l -> t (Class Type Name l) -> m (M.Map Name (Class Type Name l))
addClasses builtins typeClasses =
  foldM
    (\e0 typeClass ->
       addClass typeClass e0 >>= \e ->
         foldM (flip addInstance) e (classInstances typeClass))
    (builtinsTypeClasses builtins)
    typeClasses


================================================
FILE: src/Duet/Errors.hs
================================================
{-# LANGUAGE LambdaCase #-}

-- | Rendering of compiler exceptions as human-readable error messages.

module Duet.Errors where

import           Control.Exception
import           Data.Char
import           Data.Function
import           Data.List
import qualified Data.Map.Strict as M
import           Data.Ord
import           Data.Typeable
import           Duet.Printer
import           Duet.Types
import           Text.EditDistance

displayContextException :: ContextException -> String
displayContextException (ContextException specialTypes (SomeException se)) =
  maybe
    (maybe
       (maybe
          (maybe
             (maybe
                (displayException se)
                (displayRenamerException specialTypes)
                (cast se))
             (displayInferException specialTypes)
             (cast se))
          (displayStepperException specialTypes)
          (cast se))
       (displayResolveException specialTypes)
       (cast se))
    displayParseException
    (cast se)
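`displayContextException` tries each stage's exception type in turn (parser, resolver, stepper, type checker, renamer) using `cast`, and falls back to `displayException` if none match. The same first-match dispatch can be written flatly with `fromException` and `(<|>)` on `Maybe`. The sketch below uses two hypothetical exception types, `ParseErr` and `TypeErr`, in place of Duet's.

```haskell
import Control.Applicative ((<|>))
import Control.Exception
import Data.Maybe (fromMaybe)

-- Hypothetical stage exceptions standing in for Duet's.
data ParseErr = ParseErr String
  deriving (Show)

instance Exception ParseErr

data TypeErr = TypeErr String
  deriving (Show)

instance Exception TypeErr

-- Try each handler in order; the first successful fromException wins.
display :: SomeException -> String
display e =
  fromMaybe (displayException e) $
    (\(ParseErr m) -> "parse error: " ++ m) <$> fromException e
      <|> (\(TypeErr m) -> "type error: " ++ m) <$> fromException e
```

Adding a new stage then means adding one more `<|>` alternative rather than another level of nested `maybe`.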

displayParseException :: ParseException -> String
displayParseException e =
  case e of
    TokenizerError pe -> show pe
    ParserError pe -> show pe

displayResolveException :: SpecialTypes Name -> ResolveException -> String
displayResolveException specialTypes =
  \case
    NoInstanceFor p -> "No instance for " ++ printPredicate defaultPrint specialTypes p

displayStepperException :: a -> StepException -> String
displayStepperException _ =
  \case
    CouldntFindName n ->
      "Not in scope: " ++ curlyQuotes (printit defaultPrint n)
    CouldntFindMethodDict n ->
      "No instance dictionary for: " ++ curlyQuotes (printit defaultPrint n)
    CouldntFindNameByString n ->
      "The starter variable isn't defined: " ++
      curlyQuotes n ++ "\nPlease define a variable called " ++ curlyQuotes n
    TypeAtValueScope k -> "Type at value scope: " ++ show k

displayInferException :: SpecialTypes Name -> InferException -> [Char]
displayInferException specialTypes =
  \case
    ExplicitTypeMismatch sc1 sc2 ->
      "The type of a definition,\n\n  " ++
      printScheme defaultPrint specialTypes sc2 ++
      "\n\ndoesn't match the explicit type:\n\n  " ++
      printScheme defaultPrint specialTypes sc1
    NotInScope scope name ->
      "Not in scope " ++
      curlyQuotes (printit defaultPrint name) ++
      "\n" ++
      "Nearest names in scope:\n\n" ++
      intercalate
        ", "
        (map
           curlyQuotes
           (take
              5
              (sortBy
                 (comparing (editDistance (printit defaultPrint name)))
                 (map (printTypeSignature defaultPrint specialTypes) scope))))
    TypeMismatch t1 t2 ->
      "Couldn't match type " ++
      curlyQuotes (printType defaultPrint specialTypes t1) ++
      "\n" ++
      "against inferred type " ++ curlyQuotes (printType defaultPrint specialTypes t2)
    OccursCheckFails ->
      "Infinite type (occurs check failed).\n\
      \You probably have a self-referential value!"
    AmbiguousInstance ambiguities ->
      "Couldn't infer which instances to use for\n" ++
      unlines
        (map
           (\(Ambiguity _ ps) ->
              intercalate ", " (map (printPredicate defaultPrint specialTypes) ps))
           ambiguities)
    e -> show e

displayRenamerException :: SpecialTypes Name -> RenamerException -> [Char]
displayRenamerException specialTypes =
  wrap (\case
          IdentifierNotInVarScope scope name label ->
            "Not in variable scope " ++
            curlyQuotes (printit defaultPrint name) ++
            -- " (AST tree label: "++show label ++")"++
            "\n" ++
            "Nearest names in scope:\n\n" ++
            intercalate
              ", "
              (map
                 curlyQuotes
                 (take
                    5
                    (sortBy
                       (comparing (editDistance (printit defaultPrint name)))
                       (map (printit defaultPrint) (M.elems scope)))))
          IdentifierNotInConScope scope name ->
            "Not in constructors scope " ++
            curlyQuotes (printit defaultPrint name) ++
            "\n" ++
            "Nearest names in scope:\n\n" ++
            intercalate
              ", "
              (map
                 curlyQuotes
                 (take
                    5
                    (sortBy
                       (comparing (editDistance (printit defaultPrint name)))
                       (map (printit defaultPrint) (M.elems scope)))))
          KindTooManyArgs ty k ty2 ->
            "The type " ++
            curlyQuotes (printType defaultPrint specialTypes ty ++ " :: " ++ printKind k) ++
            " has an unexpected additional argument, " ++
            curlyQuotes (printType defaultPrint specialTypes ty2)
          ConstructorFieldKind cons typ kind ->
            "The type " ++
            curlyQuotes (printType defaultPrint specialTypes typ ++ " :: " ++ printKind kind) ++
            " is used in a field in the " ++
            curlyQuotes (printit defaultPrint cons) ++
            " constructor, but all fields \
            \should have types of kind " ++
            curlyQuotes (printKind StarKind)
          KindArgMismatch t1 k1 t2 k2 ->
            "The type " ++
            curlyQuotes (printType defaultPrint specialTypes t1 ++ " :: " ++ printKind k1) ++
            " has been given an argument of the wrong kind " ++
            curlyQuotes (printType defaultPrint specialTypes t2 ++ " :: " ++ printKind k2)
          TypeNotInScope types i ->
            "Unknown type " ++
            curlyQuotes (printIdentifier defaultPrint i) ++
            "\n" ++
            "Closest names in scope are: " ++
            intercalate
              ", "
              (map
                 curlyQuotes
                 (take
                    5
                    (sortBy
                       (comparing (editDistance (printIdentifier defaultPrint i)))
                       (map (printTypeConstructor defaultPrint) types))))
          UnknownTypeVariable types i ->
            "Unknown type variable " ++
            curlyQuotes (printIdentifier defaultPrint i) ++
            "\n" ++
            "Type variables in scope are: " ++
            intercalate
              ", "
              (map
                 curlyQuotes
                 (sortBy
                    (comparing (editDistance (printIdentifier defaultPrint i)))
                    (map (printTypeVariable defaultPrint) types)))
          e -> show e)
  where wrap f e = (f e)-- ++ "\n(" ++ show e ++ ")"

editDistance :: [Char] -> [Char] -> Int
editDistance = on (levenshteinDistance defaultEditCosts) (map toLower)
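Most "not in scope" messages above suggest alternatives by sorting the candidate names by case-insensitive Levenshtein distance (`levenshteinDistance defaultEditCosts` from the `edit-distance` package) and keeping the first five. The dependency-free sketch below reproduces that ranking. The hand-written `levenshtein` (unit costs, row-by-row dynamic programming) stands in for the library function, and `nearestNames` is a hypothetical helper.

```haskell
import Data.Char (toLower)
import Data.List (sortOn)

-- Levenshtein distance, computing one DP row per character of the second string.
levenshtein :: String -> String -> Int
levenshtein a b = last (foldl step [0 .. length a] b)
  where
    step [] _ = []
    step prev@(p : ps) c = scanl compute (p + 1) (zip3 a prev ps)
      where
        -- left: cell to the left (insertion); diag: substitution; up: deletion
        compute left (ac, diag, up) =
          minimum [up + 1, left + 1, diag + (if ac == c then 0 else 1)]

-- Case-insensitive, mirroring editDistance above.
editDistance :: String -> String -> Int
editDistance x y = levenshtein (map toLower x) (map toLower y)

-- Up to five candidates, closest first (sortOn is stable for ties).
nearestNames :: String -> [String] -> [String]
nearestNames name = take 5 . sortOn (editDistance name)
```

For example, misspelling `lenght` ranks `length` (distance 2) ahead of unrelated names such as `map` or `filter`.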


================================================
FILE: src/Duet/Infer.hs
================================================
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

-- | A clear-to-read, well-documented implementation of a Haskell 98
-- type checker adapted from Typing Haskell In Haskell, by Mark
-- P. Jones.

module Duet.Infer
  (
  -- * Type checker
  -- $type-checker
    typeCheckModule
  , byInst
  , InferException(..)
  -- * Setting up
  , addClass
  , addInstance
  , SpecialTypes(..)
  , ReadException(..)
  -- * Printers
  -- , printTypeSignature
  -- * Types syntax tree
  , Type(..)
  , Kind(..)
  , Scheme(..)
  , TypeSignature(..)
  , TypeVariable(..)
  , Qualified(..)
  , Class(..)
  , Predicate(..)
  , TypeConstructor(..)
  -- * Values syntax tree
  , ImplicitlyTypedBinding(..)
  , ExplicitlyTypedBinding(..)
  , Expression(..)
  , Literal(..)
  , Pattern(..)
  , BindGroup(..)
  , Alternative(..)
  , toScheme
  , typeKind
  , classMethodScheme
  ) where

import           Control.Arrow (first,second)
import           Control.Monad.Catch
import           Control.Monad.State
import           Data.Generics
import           Data.Graph
import           Data.List
import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import           Data.Maybe
import           Duet.Types

--------------------------------------------------------------------------------
-- Type inference
--

-- $type-checker
--
-- The type checker takes a module and produces a list of type
-- signatures. It checks that all types unify, and infers the types of
-- unannotated expressions. It resolves type-class instances.

-- | Type check the given module. Returns the bindings, grouped and
-- annotated with their inferred type signatures, together with the
-- type-classes whose instance methods have been annotated the same way.
--
-- For example, an unannotated identity function @id@ is given the
-- signature @id :: forall a0. a0 -> a0@.
--
-- Throws 'InferException' in case of a type error.
typeCheckModule ::
     (MonadThrow m)
  => Map Name (Class Type Name Location) -- ^ Set of defined type-classes.
  -> [(TypeSignature Type Name Name)] -- ^ Pre-defined type signatures e.g. for built-ins or FFI.
  -> SpecialTypes Name -- ^ Special types that Haskell uses for pattern matching and literals.
  -> [Binding Type Name Location] -- ^ Bindings in the module.
  -> m ( [BindGroup Type Name (TypeSignature Type Name Location)]
       , Map Name (Class Type Name (TypeSignature Type Name Location)))
typeCheckModule ce as specialTypes bgs0 =
  runTypeChecker (dependencyAnalysis bgs0)
  where
    runTypeChecker bgs =
      evalStateT
        (runInferT $ do
           instanceBgs <- classMethodsToGroups specialTypes ce
           (ps, _, bgs') <-
             inferSequenceTypes inferBindGroupTypes ce as (bgs ++ instanceBgs)
           s <- InferT (gets inferStateSubstitutions)
           let rs = reduce ce (map (substitutePredicate s) ps)
           s' <- defaultSubst ce [] rs
           let bgsFinal = map (fmap (substituteTypeSignature (s' @@ s))) bgs'
           ce' <- collectMethods bgsFinal ce
           return (bgsFinal, ce'))
        (InferState nullSubst 0 specialTypes)

-- | Sort the bindings so that bindings with no dependencies come
-- first, followed by the bindings that depend on them. Mutually
-- recursive bindings are grouped together into one 'BindGroup'.
dependencyAnalysis :: Data l => [Binding Type Name l] -> [BindGroup Type Name l]
dependencyAnalysis = map toBindGroup . stronglyConnComp . bindingsGraph
  where
    toBindGroup =
      \case
        AcyclicSCC binding ->
          BindGroup (explicits [binding]) [implicits [binding]]
        CyclicSCC bindings ->
          BindGroup (explicits bindings) [implicits bindings]
    explicits =
      mapMaybe
        (\case
           ExplicitBinding i -> Just i
           _ -> Nothing)
    implicits =
      mapMaybe
        (\case
           ImplicitBinding i -> Just i
           _ -> Nothing)

-- | Make a graph of the bindings with their dependencies.
bindingsGraph :: Data l => [Binding Type Name l] -> [(Binding Type Name l, Name, [Name])]
bindingsGraph =
  map
    (\binding ->
       ( binding
       , bindingIdentifier binding
       , listify
           (\case
              n@ValueName {} -> n /= bindingIdentifier binding
              _ -> False)
           (bindingAlternatives binding)))
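`bindingsGraph` produces `(node, key, dependencies)` triples, and `stronglyConnComp` from `Data.Graph` returns the strongly connected components in reverse topological order. Dependencies therefore come before the bindings that use them, and mutually recursive bindings share one component (`CyclicSCC`). The sketch below applies the same function to hypothetical bindings, with names used as both nodes and keys.

```haskell
import Data.Graph (flattenSCC, stronglyConnComp)

-- (node, key, keys this binding refers to); hypothetical example bindings.
bindings :: [(String, String, [String])]
bindings =
  [ ("main", "main", ["isEven"])
  , ("isEven", "isEven", ["isOdd"]) -- isEven and isOdd are mutually recursive
  , ("isOdd", "isOdd", ["isEven"])
  , ("helper", "helper", [])
  ]

-- Each inner list is one group, with dependencies ordered first.
groups :: [[String]]
groups = map flattenSCC (stronglyConnComp bindings)
```

Here `isEven` and `isOdd` form a single group that precedes the group for `main`, which is the order the type checker needs: it can generalise the types of `isEven` and `isOdd` before it checks `main`.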

collectMethods ::
     forall l m. MonadThrow m
  => [BindGroup Type Name (TypeSignature Type Name l)]
  -> Map Name (Class Type Name l)
  -> m (Map Name (Class Type Name (TypeSignature Type Name l)))
collectMethods binds =
  fmap M.fromList .
  mapM
    (\(name, cls) -> do
       insts <-
         mapM
           (\inst -> do
              methods <-
                mapM
                  collectMethod
                  (M.toList (dictionaryMethods (instanceDictionary inst)))
              pure
                inst
                { instanceDictionary =
                    (instanceDictionary inst)
                    {dictionaryMethods = M.fromList methods}
                })
           (classInstances cls)
       pure (name, cls {classInstances = insts})) .
  M.toList
  where
    collectMethod ::
         (Name, (l, t))
      -> m ( Name
           , ( TypeSignature Type Name l
             , Alternative Type Name (TypeSignature Type Name l)))
    collectMethod (key, (l, _)) =
      case listToMaybe
             (mapMaybe
                (\(BindGroup ex _) ->
                   listToMaybe
                     (mapMaybe
                        (\i ->
                           if fst (explicitlyTypedBindingId i) == key
                             then listToMaybe
                                    (explicitlyTypedBindingAlternatives i)
                             else Nothing)
                        ex))
                binds) of
        Just alt ->
          pure
            ( key
            , ( TypeSignature l (typeSignatureScheme (alternativeLabel alt))
              , alt))
        Nothing -> throwM MissingMethod

classMethodsToGroups
  :: MonadThrow m
  => SpecialTypes Name -> Map Name (Class Type Name l) -> m [BindGroup Type Name l]
classMethodsToGroups specialTypes =
  mapM
    (\class' ->
       BindGroup <$>
       fmap
         concat
         (mapM
            (\inst ->
               sequence
                 (zipWith
                    (\methodScheme (instMethodName, (l, methodAlt)) ->
                       ExplicitlyTypedBinding <$> pure l <*>
                       pure (instMethodName, l) <*>
                       instanceMethodScheme
                         specialTypes
                         class'
                         methodScheme
                         (instancePredicate inst) <*>
                       pure [methodAlt])
                    (M.elems (classMethods class'))
                    (M.toList (dictionaryMethods (instanceDictionary inst)))))
            (classInstances class')) <*>
       pure []) .
  M.elems

instanceMethodScheme
  :: MonadThrow m
  => SpecialTypes Name
  -> Class Type Name l
  -> Scheme Type Name Type
  -> Scheme Type Name (Predicate Type)
  -> m (Scheme Type Name Type)
instanceMethodScheme _specialTypes cls (Forall methodVars0 (Qualified methodPreds methodType0)) _instScheme@(Forall instanceVars0 (Qualified preds (IsIn _ headTypes))) = do
  methodQual <- instantiateQual (Qualified (methodPreds ++ preds) methodType0)
  pure (Forall methodVars methodQual)
  where
    methodVars = filter (not . flip elem (classTypeVariables cls)) (methodVars0 ++ instanceVars0)
    table = zip (classTypeVariables cls) headTypes
    instantiateQual (Qualified ps t) =
      Qualified <$> mapM instantiatePred ps <*> instantiate t
    instantiatePred (IsIn c t) = IsIn c <$> mapM instantiate t
    instantiate =
      \case
        ty@(VariableType tyVar) ->
          case lookup tyVar table of
            Nothing -> pure ty
            Just typ -> pure typ
        ApplicationType a b ->
          ApplicationType <$> instantiate a <*> instantiate b
        typ -> pure typ

classMethodScheme
  :: MonadThrow m
  => Class t Name l -> Scheme Type Name Type -> m (Scheme Type Name Type)
classMethodScheme cls (Forall methodVars (Qualified methodPreds methodType)) =
  pure
    (Forall
       methodVars
       (Qualified
          (methodPreds ++
           [IsIn (className cls) (map VariableType (classTypeVariables cls))])
          methodType))
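These two functions derive method types at two levels. At the class level, `classMethodScheme` adds the class constraint, so `equal :: a -> a -> Bool` in `class Equal a` becomes `Equal a => a -> a -> Bool`. At the instance level, `instanceMethodScheme` appends the instance context and replaces each class variable with the instance head type, so in `instance Equal Nat` the method is checked at `Nat -> Nat -> Bool`. The sketch below uses a hypothetical miniature AST (`Ty`, `Pred`, `Qual`) and omits the quantifier lists.

```haskell
import Data.Maybe (fromMaybe)

data Ty = TVar String | TCon String | TAp Ty Ty
  deriving (Show, Eq)

data Pred = IsIn String [Ty]
  deriving (Show, Eq)

data Qual = Qual [Pred] Ty
  deriving (Show, Eq)

arrow :: Ty -> Ty -> Ty
arrow a b = TAp (TAp (TCon "->") a) b

-- Class level: append the class constraint over the class variables.
classMethodQual :: String -> [String] -> Qual -> Qual
classMethodQual cls vars (Qual ps t) = Qual (ps ++ [IsIn cls (map TVar vars)]) t

-- Instance level: add the instance context, then substitute each class
-- variable with its instance head type. Class and instance variable
-- names are assumed to be distinct here.
instanceMethodQual :: [String] -> [Ty] -> [Pred] -> Qual -> Qual
instanceMethodQual vars heads instPreds (Qual ps t) =
  Qual (map substPred (ps ++ instPreds)) (subst t)
  where
    table = zip vars heads
    subst (TVar v) = fromMaybe (TVar v) (lookup v table)
    subst (TAp a b) = TAp (subst a) (subst b)
    subst c = c
    substPred (IsIn c ts) = IsIn c (map subst ts)
```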

--------------------------------------------------------------------------------
-- Substitution

infixr 4 @@
(@@) :: [Substitution Name] -> [Substitution Name] -> [Substitution Name]
s1 @@ s2 = [Substitution u (substituteType s1 t) | (Substitution u t) <- s2] ++ s1

nullSubst :: [Substitution Name]
nullSubst = []

substituteQualified :: [Substitution Name] -> Qualified Type Name (Type Name) -> Qualified Type Name (Type Name)
substituteQualified substitutions (Qualified predicates t) =
  Qualified
    (map (substitutePredicate substitutions) predicates)
    (substituteType substitutions t)

substituteTypeSignature :: [Substitution Name] -> (TypeSignature Type Name l) -> (TypeSignature Type Name l)
substituteTypeSignature substitutions (TypeSignature l scheme) =
    TypeSignature l (substituteInScheme substitutions scheme)
  where substituteInScheme subs' (Forall kinds qualified) =
          Forall kinds (substituteQualified subs' qualified)

substitutePredicate :: [Substitution Name] -> Predicate Type Name -> Predicate Type Name
substitutePredicate substitutions (IsIn identifier types) =
    IsIn identifier (map (substituteType substitutions) types)

substituteType :: [Substitution Name] -> Type Name -> Type Name
substituteType substitutions (VariableType typeVariable) =
    case find ((== typeVariable) . substitutionTypeVariable) substitutions of
      Just substitution -> substitutionType substitution
      Nothing -> VariableType typeVariable
substituteType substitutions (ApplicationType type1 type2) =
    ApplicationType
      (substituteType substitutions type1)
      (substituteType substitutions type2)
substituteType _ typ = typ
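`s1 @@ s2` composes substitutions so that applying the result is the same as applying `s2` first and then `s1`. It applies `s1` to every type in `s2` and then appends `s1` for variables that `s2` does not mention. The law can be checked on a hypothetical miniature type AST. `apply` here mirrors `substituteType`, using a lookup list instead of `Substitution` records.

```haskell
import Data.Maybe (fromMaybe)

data Ty = TVar String | TCon String | TAp Ty Ty
  deriving (Show, Eq)

type Subst = [(String, Ty)]

-- Replace variables by their bound types; the first binding wins, as with find.
apply :: Subst -> Ty -> Ty
apply s (TVar v) = fromMaybe (TVar v) (lookup v s)
apply s (TAp a b) = TAp (apply s a) (apply s b)
apply _ t = t

infixr 4 @@

-- Composition: apply (s1 @@ s2) == apply s1 . apply s2
(@@) :: Subst -> Subst -> Subst
s1 @@ s2 = [(u, apply s1 t) | (u, t) <- s2] ++ s1
```

For instance, with `s2 = [("a", List b)]` and `s1 = [("b", Int)]`, applying `s1 @@ s2` to `a b` gives `(List Int) Int`, the same as applying `s2` and then `s1`.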

--------------------------------------------------------------------------------
-- Type inference

unify :: MonadThrow m => Type Name -> Type Name -> InferT m ()
unify t1 t2 = do
  s <- InferT (gets inferStateSubstitutions)
  u <- unifyTypes (substituteType s t1) (substituteType s t2)
  InferT
    (modify
       (\s' -> s' {inferStateSubstitutions = u @@ inferStateSubstitutions s'}))
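`unify` applies the current substitution to both types, asks `unifyTypes` (not shown in this excerpt) for a most general unifier, and composes the result onto the inference state with `@@`. The sketch below shows what such a most-general-unifier function computes, in the style of *Typing Haskell in Haskell*, including the occurs check behind `OccursCheckFails`. It uses a hypothetical miniature AST with no kinds, and `Either String` replaces `MonadThrow`.

```haskell
import Data.Maybe (fromMaybe)

data Ty = TVar String | TCon String | TAp Ty Ty
  deriving (Show, Eq)

type Subst = [(String, Ty)]

apply :: Subst -> Ty -> Ty
apply s (TVar v) = fromMaybe (TVar v) (lookup v s)
apply s (TAp a b) = TAp (apply s a) (apply s b)
apply _ t = t

(@@) :: Subst -> Subst -> Subst
s1 @@ s2 = [(u, apply s1 t) | (u, t) <- s2] ++ s1

tyVars :: Ty -> [String]
tyVars (TVar v) = [v]
tyVars (TAp a b) = tyVars a ++ tyVars b
tyVars _ = []

-- Most general unifier: unify the function parts, then the arguments
-- under the substitution found so far.
mgu :: Ty -> Ty -> Either String Subst
mgu (TAp l r) (TAp l' r') = do
  s1 <- mgu l l'
  s2 <- mgu (apply s1 r) (apply s1 r')
  pure (s2 @@ s1)
mgu (TVar v) t = bindVar v t
mgu t (TVar v) = bindVar v t
mgu (TCon a) (TCon b) | a == b = Right []
mgu a b = Left ("cannot unify " ++ show a ++ " with " ++ show b)

-- Bind a variable, refusing infinite types such as a ~ List a.
bindVar :: String -> Ty -> Either String Subst
bindVar v t
  | t == TVar v = Right []
  | v `elem` tyVars t = Left "occurs check fails"
  | otherwise = Right [(v, t)]
```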

newVariableType :: Monad m => Kind -> InferT m (Type Name)
newVariableType k =
  InferT
    (do inferState <- get
        put inferState {inferStateCounter = inferStateCounter inferState + 1}
        return
          (VariableType (TypeVariable (enumId (inferStateCounter inferState)) k)))

inferExplicitlyTypedBindingType
  :: (MonadThrow m, Show l)
  => Map Name (Class Type Name l)
  -> [TypeSignature Type Name Name]
  -> (ExplicitlyTypedBinding Type Name l)
  -> InferT m ([Predicate Type Name], ExplicitlyTypedBinding Type Name (TypeSignature Type Name l))
inferExplicitlyTypedBindingType ce as (ExplicitlyTypedBinding l (identifier, l') sc alts) = do
  (Qualified qs t) <- freshInst sc
  (ps, alts') <- inferAltTypes ce as alts t
  s <- InferT (gets inferStateSubstitutions)
  let qs' = map (substitutePredicate s) qs
      t' = substituteType s t
      fs =
        getTypeVariablesOf
          getTypeSignatureTypeVariables
          (map (substituteTypeSignature s) as)
      gs = getTypeTypeVariables t' \\ fs
      sc' = quantify gs (Qualified qs' t')
      ps' = filter (not . entail ce qs') (map (substitutePredicate s) ps)
  (ds, rs) <- split ce fs gs ps'
  if not (sc `schemesEquivalent` sc')
    then throwM (ExplicitTypeMismatch sc sc')
    else if not (null rs)
           then throwM ContextTooWeak
           else return
                  ( ds
                  , ExplicitlyTypedBinding
                      (TypeSignature l sc)
                      (identifier, TypeSignature l' sc)
                      sc
                      alts')

-- | Are two type schemes alpha-equivalent?
schemesEquivalent :: Scheme Type Name Type ->  Scheme Type Name Type -> Bool
schemesEquivalent (Forall vs1 q1) (Forall vs2 q2) =
  length vs1 == length vs2 &&
  evalState (goQ q1 q2) (mempty,mempty)
  where
    goQ (Qualified ps1 t1) (Qualified ps2 t2)
      | length ps1 /= length ps2 = pure False
      | otherwise =
        (&&) <$> fmap and (sequence (zipWith goPred ps1 ps2)) <*> goType t1 t2
    goPred (IsIn x ts1) (IsIn y ts2)
      | x /= y || length ts1 /= length ts2 = pure False
      | otherwise = and <$> sequence (zipWith goType ts1 ts2)
    goType (VariableType tv1) (VariableType tv2) = do
      i <- bind fst first tv1
      j <- bind snd second tv2
      pure (i == j)
    goType (ConstructorType c1) (ConstructorType c2) = pure (c1 == c2)
    goType (ApplicationType f1 a1) (ApplicationType f2 a2) =
      (&&) <$> goType f1 f2 <*> goType a1 a2
    goType _ _ = pure False
    bind the upon tv = do
      ctx <- gets the
      case M.lookup tv ctx of
        Nothing -> do
          modify (upon (M.insert tv (M.size ctx)))
          pure (M.size ctx)
        Just j -> pure j
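`schemesEquivalent` numbers the type variables on each side in order of first occurrence (the two maps carried in `State`) and requires matching positions to receive the same number. The same idea can be expressed as a canonical renaming: rewrite each type so that every variable becomes the index of its first occurrence, then compare the results syntactically. The sketch below covers bare types only, using a hypothetical `Ty`. It does not compare the predicates or quantifier counts that Duet also checks.

```haskell
import Data.List (nub)

data Ty = TVar String | TCon String | TAp Ty Ty
  deriving (Show, Eq)

arrow :: Ty -> Ty -> Ty
arrow a b = TAp (TAp (TCon "->") a) b

-- Type variables in left-to-right order of first occurrence.
varsInOrder :: Ty -> [String]
varsInOrder = nub . go
  where
    go (TVar v) = [v]
    go (TAp a b) = go a ++ go b
    go (TCon _) = []

-- Rename each variable to the position of its first occurrence.
canonical :: Ty -> Ty
canonical t = rename t
  where
    order = varsInOrder t
    rename (TVar v) = TVar (show (length (takeWhile (/= v) order)))
    rename (TAp a b) = TAp (rename a) (rename b)
    rename c = c

alphaEquivalent :: Ty -> Ty -> Bool
alphaEquivalent x y = canonical x == canonical y
```

So `a -> b -> a` and `x -> y -> x` are equivalent (both become `0 -> 1 -> 0`), but `x -> y -> y` is not.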

inferImplicitlyTypedBindingsTypes
  :: (MonadThrow m, Show l)
  => Map Name (Class Type Name l)
  -> [(TypeSignature Type Name Name)]
  -> [ImplicitlyTypedBinding Type Name l]
  -> InferT m ([Predicate Type Name], [(TypeSignature Type Name Name)], [ImplicitlyTypedBinding Type Name (TypeSignature Type Name l)])
inferImplicitlyTypedBindingsTypes ce as bs = do
  ts <- mapM (\_ -> newVariableType StarKind) bs
  let is = map (fst . implicitlyTypedBindingId) bs
      scs = map toScheme ts
      as' = zipWith (\x y -> TypeSignature x y) is scs ++ as
  pss0 <-
    sequence
      (zipWith
         (\b t -> inferAltTypes ce as' (implicitlyTypedBindingAlternatives b) t)
         bs
         ts)
  let pss = map fst pss0
      binds' = map snd pss0
  s <- InferT (gets inferStateSubstitutions)
  let ps' = map (substitutePredicate s) (concat pss)
      ts' = map (substituteType s) ts
      fs =
        getTypeVariablesOf
          getTypeSignatureTypeVariables
          (map (substituteTypeSignature s) as)
      vss = map getTypeTypeVariables ts'
      gs = foldr1' union vss \\ fs
  (ds, rs) <- split ce fs (foldr1' intersect vss) ps'
  if restrictImplicitlyTypedBindings bs
    then let gs' = gs \\ getTypeVariablesOf getPredicateTypeVariables rs
             scs' = map (quantify gs' . (Qualified [])) ts'
         in return
              ( ds ++ rs
              , zipWith (\x y -> TypeSignature x y) is scs'
              , zipWith
                  (\(ImplicitlyTypedBinding l (tid, l') _, binds'') scheme ->
                     ImplicitlyTypedBinding
                       (TypeSignature l scheme)
                       (tid, TypeSignature l' scheme)
                       binds'')
                  (zip bs binds')
                  scs')
    else let scs' = map (quantify gs . (Qualified rs)) ts'
         in return
              ( ds
              , zipWith (\x y -> TypeSignature x y) is scs'
              , zipWith
                  (\(ImplicitlyTypedBinding l (tid, l') _, binds'') scheme ->
                     ImplicitlyTypedBinding (TypeSignature l scheme) (tid,TypeSignature l' scheme) binds'')
                  (zip bs binds')
                  scs')
  where
    foldr1' f xs =
      if null xs
        then []
        else foldr1 f xs

inferBindGroupTypes
  :: (MonadThrow m, Show l)
  => Map Name (Class Type Name l)
  -> [(TypeSignature Type Name Name)]
  -> (BindGroup Type Name l)
  -> InferT m ([Predicate Type Name], [(TypeSignature Type Name Name)], BindGroup Type Name (TypeSignature Type Name l))
inferBindGroupTypes ce as (BindGroup es iss) = do
  let as' = [TypeSignature v sc | ExplicitlyTypedBinding _ (v, _) sc _alts <- es]
  (ps, as'', iss') <-
    inferSequenceTypes0 inferImplicitlyTypedBindingsTypes ce (as' ++ as) iss
  qss <- mapM (inferExplicitlyTypedBindingType ce (as'' ++ as' ++ as)) es
  return (ps ++ concat (map fst qss), as'' ++ as', BindGroup (map snd qss) iss')

inferSequenceTypes0
  :: Monad m
  => (Map Name (Class Type Name l) -> [(TypeSignature Type Name Name)] -> [bg l] -> InferT m ([Predicate Type Name], [(TypeSignature Type Name Name)], [bg (TypeSignature Type Name l)]))
  -> Map Name (Class Type Name l)
  -> [(TypeSignature Type Name Name)]
  -> [[bg l]]
  -> InferT m ([Predicate Type Name], [(TypeSignature Type Name Name)], [[bg (TypeSignature Type Name l)]])
inferSequenceTypes0 _ _ _ [] = return ([], [], [])
inferSequenceTypes0 ti ce as (bs:bss) = do
  (ps, as', bs') <- ti ce as bs
  (qs, as'', bss') <- inferSequenceTypes0 ti ce (as' ++ as) bss
  return (ps ++ qs, as'' ++ as', bs' : bss')

inferSequenceTypes
  :: Monad m
  => (Map Name (Class Type Name l) -> [(TypeSignature Type Name Name)] -> bg l -> InferT m ([Predicate Type Name], [(TypeSignature Type Name Name)], bg (TypeSignature Type Name l)))
  -> Map Name (Class Type Name l)
  -> [(TypeSignature Type Name Name)]
  -> [bg l]
  -> InferT m ([Predicate Type Name], [(TypeSignature Type Name Name)], [bg (TypeSignature Type Name l)])
inferSequenceTypes _ _ _ [] = return ([], [], [])
inferSequenceTypes ti ce as (bs:bss) = do
  (ps, as', bs') <- ti ce as bs
  (qs, as'', bss') <- inferSequenceTypes ti ce (as' ++ as) bss
  return (ps ++ qs, as'' ++ as', bs' : bss')

--------------------------------------------------------------------------------
-- Instantiation

instantiateType :: [(TypeVariable Name, Type Name)] -> Type Name -> Type Name
instantiateType ts (ApplicationType l r) =
  ApplicationType (instantiateType ts l) (instantiateType ts r)
instantiateType ts ty@(VariableType tyvar) =
  case lookup tyvar ts of
    Nothing -> ty
    Just ty' -> ty' -- TODO: possibly throw error here?
-- instantiateType ts (GenericType n) = ts !! n
instantiateType _ t = t

instantiateQualified :: [(TypeVariable Name, Type Name)] -> Qualified Type Name (Type Name) -> Qualified Type Name (Type Name)
instantiateQualified ts (Qualified ps t) =
  Qualified (map (instantiatePredicate ts) ps) (instantiateType ts t)

instantiatePredicate :: [(TypeVariable Name, Type Name)] -> Predicate Type Name -> Predicate Type Name
instantiatePredicate ts (IsIn c t) = IsIn c (map (instantiateType ts) t)

--------------------------------------------------------------------------------
-- Type variables

getTypeSignatureTypeVariables :: (TypeSignature Type Name Name) -> [TypeVariable Name]
getTypeSignatureTypeVariables = getTypeVariables where
  getTypeVariables (TypeSignature _  scheme) = getSchemeTypeVariables scheme
    where getSchemeTypeVariables (Forall _ qualified) = getQualifiedTypeVariables qualified

getQualifiedTypeVariables :: Qualified Type Name (Type Name) -> [TypeVariable Name]
getQualifiedTypeVariables = getTypeVariables
  where
    getTypeVariables (Qualified predicates t) =
      getTypeVariablesOf getPredicateTypeVariables predicates `union`
      getTypeTypeVariables t

getPredicateTypeVariables :: Predicate Type Name -> [TypeVariable Name]
getPredicateTypeVariables (IsIn _ types) = getTypeVariablesOf getTypeTypeVariables types

getTypeTypeVariables :: Type Name -> [TypeVariable Name]
getTypeTypeVariables = getTypeVariables where
  getTypeVariables (VariableType typeVariable) = [typeVariable]
  getTypeVariables (ApplicationType type1 type2) =
    getTypeVariables type1 `union` getTypeVariables type2
  getTypeVariables _ = []

getTypeVariablesOf :: (a -> [TypeVariable Name]) -> [a] -> [TypeVariable Name]
getTypeVariablesOf f = nub . concatMap f

-- | Get the kind of a type.
typeKind :: Type Name -> Kind
typeKind (ConstructorType typeConstructor) = typeConstructorKind typeConstructor
typeKind (VariableType typeVariable) = typeVariableKind typeVariable
typeKind (ApplicationType typ _) =
  case (typeKind typ) of
    (FunctionKind _ kind) -> kind
    k -> k

--------------------------------------------------------------------------------
-- GOOD NAMING CONVENTION, UNSORTED

-- | The monomorphism restriction is invoked when one or more of the
-- entries in a list of implicitly typed bindings is simple, meaning
-- that it has an alternative with no left-hand side patterns. The
-- following function provides a way to test for this:
restrictImplicitlyTypedBindings :: [(ImplicitlyTypedBinding t Name l)] -> Bool
restrictImplicitlyTypedBindings = any simple
  where
    simple =
      any (null . alternativePatterns) . implicitlyTypedBindingAlternatives
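
-- Worked example, assuming a class Show with a method
-- show :: a -> String is in scope. The binding
--
--   showIt = show
--
-- has an alternative with no patterns, so it is simple and its group is
-- restricted: the type variable constrained by Show is not generalized,
-- and the Show predicate is returned to the enclosing scope instead. By
-- contrast, the binding
--
--   showIt x = show x
--
-- is not simple and is generalized to forall a. Show a => a -> String.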

-- | Calculate the ambiguous type variables: those that occur in the
-- predicates but not in the given list of type variables. Each one is
-- paired with the predicates that mention it, which any choice of
-- default type must satisfy.
ambiguities :: [TypeVariable Name] -> [Predicate Type Name] -> [Ambiguity Name]
ambiguities typeVariables predicates =
  [ Ambiguity typeVariable (filter (elem typeVariable . getPredicateTypeVariables) predicates)
  | typeVariable <- getTypeVariablesOf getPredicateTypeVariables predicates \\ typeVariables
  ]
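
-- Worked example: with typeVariables = [a] and
-- predicates = [Show a, Show b, Eq b], the variable b occurs in the
-- predicates but not in [a], so the result is
-- [Ambiguity b [Show b, Eq b]].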

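-- | Unify a type variable with a type. This returns the empty
-- substitution if the type is the variable itself, throws
-- 'OccursCheckFails' if the variable occurs inside the type, throws
-- 'KindMismatch' if the kinds differ, and otherwise binds the variable
-- to the type.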
unifyTypeVariable :: MonadThrow m => TypeVariable Name -> Type Name -> m [Substitution Name]
unifyTypeVariable typeVariable typ
  | typ == VariableType typeVariable = return nullSubst
  | typeVariable `elem` getTypeTypeVariables typ = throwM OccursCheckFails
  | typeVariableKind typeVariable /= typeKind typ = throwM KindMismatch
  | otherwise = return [Substitution typeVariable typ]
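
-- Worked example for a type variable a of kind *, writing types in
-- source syntax for brevity (List and Maybe are assumed type
-- constructors of kinds * -> *):
--
--   unifyTypeVariable a a          -- nullSubst
--   unifyTypeVariable a Integer    -- [Substitution a Integer]
--   unifyTypeVariable a (List a)   -- throws OccursCheckFails
--   unifyTypeVariable a Maybe      -- throws KindMismatch
--
-- The guards are tried in that order, so the occurs check is performed
-- before the kind check.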

unifyPredicates :: Predicate Type Name -> Predicate Type Name -> Maybe [Substitution Name]
unifyPredicates = lift' unifyTypeList

oneWayMatchPredicate :: Predicate Type Name -> Predicate Type Name -> Maybe [Substitution Name]
oneWayMatchPredicate = lift' oneWayMatchLists

unifyTypes :: MonadThrow m => Type Name -> Type Name -> m [Substitution Name]
unifyTypes (ApplicationType l r) (ApplicationType l' r') = do
  s1 <- unifyTypes l l'
  s2 <- unifyTypes (substituteType s1 r) (substituteType s1 r')
  return (s2 @@ s1)
unifyTypes (VariableType u) t = unifyTypeVariable u t
unifyTypes t (VariableType u) = unifyTypeVariable u t
unifyTypes (ConstructorType tc1) (ConstructorType tc2)
  | tc1 == tc2 = return nullSubst
unifyTypes a b = throwM (TypeMismatch a b)

unifyTypeList :: MonadThrow m => [Type Name] -> [Type Name] -> m [Substitution Name]
unifyTypeList (x:xs) (y:ys) = do
    s1 <- unifyTypes x y
    s2 <- unifyTypeList (map (substituteType s1) xs) (map (substituteType s1) ys)
    return (s2 @@ s1)
unifyTypeList [] [] = return nullSubst
unifyTypeList _ _ = throwM ListsDoNotUnify

oneWayMatchType :: MonadThrow m => Type Name -> Type Name -> m [Substitution Name]
oneWayMatchType (ApplicationType l r) (ApplicationType l' r') = do
  sl <- oneWayMatchType l l'
  sr <- oneWayMatchType r r'
  merge sl sr
oneWayMatchType (VariableType u) t
  | typeVariableKind u == typeKind t = return [Substitution u t]
oneWayMatchType (ConstructorType tc1) (ConstructorType tc2)
  | tc1 == tc2 = return nullSubst
oneWayMatchType _ _ = throwM TypeMismatchOneWay

oneWayMatchLists :: MonadThrow m => [Type Name] -> [Type Name] -> m [Substitution Name]
oneWayMatchLists ts ts' = do
    ss <- sequence (zipWith oneWayMatchType ts ts')
    foldM merge nullSubst ss
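
-- Worked example contrasting one-way matching with unification, writing
-- types in source syntax for brevity (a is a type variable of kind *):
--
--   oneWayMatchType (a -> a) (Integer -> Bool)
--     -- matches a against Integer and a against Bool separately, then
--     -- throws MergeFail because the two bindings for a disagree
--   unifyTypes (a -> a) (Integer -> Bool)
--     -- binds a to Integer, then throws TypeMismatch on Integer vs Bool
--   oneWayMatchType Integer a
--     -- throws TypeMismatchOneWay: only variables in the first type
--     -- may be bound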

--------------------------------------------------------------------------------
-- Garbage

lookupName
  :: MonadThrow m
  => Name -> [(TypeSignature Type Name Name)] -> m (Scheme Type Name Type)
lookupName name cands = go name cands where
  go n [] = throwM (NotInScope cands n)
  go i ((TypeSignature i'  sc):as) =
    if i == i'
      then return sc
      else go i as

enumId :: Int -> Name
enumId n = ForallName n

inferLiteralType
  :: Monad m
  => SpecialTypes Name -> Literal -> InferT m ([Predicate Type Name], Type Name)
inferLiteralType specialTypes (CharacterLiteral _) =
  return ([], ConstructorType (specialTypesChar specialTypes))
inferLiteralType specialTypes (IntegerLiteral _) = do
  return ([], ConstructorType (specialTypesInteger specialTypes))
inferLiteralType specialTypes (StringLiteral _) =
  return ([], ConstructorType (specialTypesString specialTypes))
inferLiteralType specialTypes (RationalLiteral _) = do
  return ([], ConstructorType (specialTypesRational specialTypes))

inferPattern
  :: MonadThrow m
  => [TypeSignature Type Name Name] -> Pattern Type Name l
  -> InferT m (Pattern Type Name (TypeSignature Type Name l), [Predicate Type Name], [(TypeSignature Type Name Name)], Type Name)
inferPattern signatures = go
  where
    go (BangPattern p) = do
      (p', x, y, z) <- go p
      pure (BangPattern p', x, y, z)
    go (VariablePattern l i) = do
      v <- newVariableType StarKind
      return
        ( VariablePattern (TypeSignature l (toScheme v)) i
        , []
        , [TypeSignature i (toScheme v)]
        , v)
    go (WildcardPattern l s) = do
      v <- newVariableType StarKind
      return (WildcardPattern (TypeSignature l (toScheme v)) s, [], [], v)
    go (AsPattern l i pat) = do
      (pat', ps, as, t) <- go pat
      return
        ( AsPattern (TypeSignature l (toScheme t)) i pat'
        , ps
        , (TypeSignature i (toScheme t)) : as
        , t)
    go (LiteralPattern l0 l) = do
      specialTypes <- InferT (gets inferStateSpecialTypes)
      (ps, t) <- inferLiteralType specialTypes l
      return (LiteralPattern (TypeSignature l0 (toScheme t)) l, ps, [], t)
    go (ConstructorPattern l i pats) = do
      TypeSignature _ sc <- substituteConstr signatures i
      (pats', ps, as, ts) <- inferPatterns signatures pats
      t' <- newVariableType StarKind
      (Qualified qs t) <- freshInst sc
      specialTypes <- InferT (gets inferStateSpecialTypes)
      let makeArrow :: Type Name -> Type Name -> Type Name
          a `makeArrow` b =
            ApplicationType
              (ApplicationType
                 (ConstructorType (specialTypesFunction specialTypes))
                 a)
              b
      unify t (foldr makeArrow t' ts)
      return
        ( ConstructorPattern (TypeSignature l (toScheme t')) i pats'
        , ps ++ qs
        , as
        , t')
-- inferPattern (LazyPattern pat) = inferPattern pat

substituteConstr
  :: MonadThrow m
  => [TypeSignature Type Name Name] -> Name -> m (TypeSignature Type Name Name)
substituteConstr subs i =
  case find
         (\case
            TypeSignature i' _ -> i' == i)
         subs of
    Just sig -> pure sig
    _ ->
      throwM
        (NameNotInConScope
           (filter
              (\case
                 TypeSignature (ConstructorName _ _) _ -> True
                 _ -> False)
              subs)
           i)

inferPatterns
  :: MonadThrow m
  => [TypeSignature Type Name Name] -> [Pattern Type Name l] -> InferT m ([Pattern Type Name (TypeSignature Type Name l)], [Predicate Type Name], [(TypeSignature Type Name Name)], [Type Name])
inferPatterns ss pats = do
  psasts <- mapM (inferPattern ss) pats
  let ps = concat [ps' | (_,ps', _, _) <- psasts]
      as = concat [as' | (_,_, as', _) <- psasts]
      ts = [t | (_, _, _, t) <- psasts]
      pats' = [ p | (p,_,_,_) <- psasts]
  return (pats', ps, as, ts)

predHead :: Predicate Type Name -> Name
predHead (IsIn i _) = i

lift'
  :: MonadThrow m
  => ([Type Name] -> [Type Name] -> m a) -> Predicate Type Name -> Predicate Type Name -> m a
lift' m (IsIn i ts) (IsIn i' ts')
  | i == i' = m ts ts'
  | otherwise = throwM ClassMismatch

-- lookupClassTypeVariables :: Map Name (Class Type Name l) -> Name -> [TypeVariable Name]
-- lookupClassTypeVariables ce i =
--   fromMaybe
--     []
--     (fmap classTypeVariables (M.lookup i ce))

-- lookupClassSuperclasses :: Map Name (Class Type Name l) -> Name -> [Predicate Type Name]
-- lookupClassSuperclasses ce i = maybe [] classSuperclasses (M.lookup i ce)

-- lookupClassMethods :: Map Name (Class Type Name l) -> Name -> Map Name (Type Name)
-- lookupClassMethods ce i = maybe mempty classMethods (M.lookup i ce)

-- lookupClassInstances :: Map Name (Class Type Name l) -> Name -> [Instance Type Name l]
-- lookupClassInstances ce i =
--   maybe [] classInstances (M.lookup i ce)

defined :: Maybe a -> Bool
defined (Just _) = True
defined Nothing = False


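-- | Add a class to the environment. Example, where @numName@ and
-- @varName@ stand for 'Name's already in scope:
--
-- @
-- env <- addClass
--          (Class
--             { className = numName
--             , classTypeVariables = [TypeVariable varName StarKind]
--             , classSuperclasses = []
--             , classInstances = []
--             , classMethods = mempty
--             })
--          mempty
-- @
--
-- Any instances carried by the given class are dropped; add them with
-- 'addInstance'. Throws 'ClassAlreadyDefined' if the class is already in
-- the environment, or 'UndefinedSuperclass' if one of its superclasses
-- is not defined.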
addClass
  :: MonadThrow m
  => Class Type Name l
  -> Map Name (Class Type Name l)
  -> m (Map Name (Class Type Name l))
addClass (Class vs ps _ i methods) ce
  | defined (M.lookup i ce) = throwM ClassAlreadyDefined
  | any (not . defined . flip M.lookup ce . predHead) ps =
    throwM UndefinedSuperclass
  | otherwise = return (M.insert i (Class vs ps [] i methods) ce)


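-- | Add an instance to its class in the environment. Example, where
-- @numName@, @integerType@ and @dict@ stand for a class 'Name', a
-- 'Type' and a 'Dictionary' already in scope:
--
-- @
-- env' <- addInstance
--           (Instance (Forall [] (Qualified [] (IsIn numName [integerType]))) dict)
--           env
-- @
--
-- Throws 'NoSuchClassForInstance' if the class is not in the
-- environment, or 'OverlappingInstance' if the instance head unifies
-- with the head of an existing instance.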
addInstance
  :: MonadThrow m
  => Instance Type Name l
  -> Map Name (Class Type Name l)
  -> m (Map Name (Class Type Name l))
addInstance (Instance (Forall vs (Qualified preds p@(IsIn i _))) dict) ce =
  case M.lookup i ce of
    Nothing -> throwM NoSuchClassForInstance
    Just typeClass
      | any (overlap p) qs -> throwM OverlappingInstance
      | otherwise -> return (M.insert i c ce)
      where its = classInstances typeClass
            qs = [q | Instance (Forall _ (Qualified _ q)) _ <- its]
            ps = []
            c =
              (Class
                 (classTypeVariables typeClass)
                 (classSuperclasses typeClass)
                 (Instance (Forall vs (Qualified (nub (ps ++ preds)) p)) dict :
                  its)
                 i
                 (classMethods typeClass))

overlap :: Predicate Type Name -> Predicate Type Name -> Bool
overlap p q = defined (unifyPredicates p q)

bySuper :: Map Name (Class Type Name l) -> Predicate Type Name -> [Predicate Type Name]
bySuper ce p@(IsIn i ts) = p : concat (map (bySuper ce) supers)
  where
    supers =
      map
        (substitutePredicate substitutions)
        (maybe [] classSuperclasses (M.lookup i ce))
    substitutions =
      zipWith Substitution (maybe [] classTypeVariables (M.lookup i ce)) ts

byInst
  :: Map Name (Class Type Name l)
  -> Predicate Type Name
  -> Maybe ([Predicate Type Name], Dictionary Type Name l)
byInst ce p@(IsIn i _) =
  case M.lookup i ce of
    Nothing -> throwM NoSuchClassForInstance
    Just typeClass ->
      (msum [tryInst it | it <- classInstances typeClass])
  where
    tryInst (Instance (Forall _ (Qualified ps h)) dict) =
      case oneWayMatchPredicate h p of
        Just u -> Just (map (substitutePredicate u) ps, dict)
        Nothing -> Nothing

entail :: Show l =>  Map Name (Class Type Name l) -> [Predicate Type Name] -> Predicate Type Name -> Bool
entail ce ps p =
  any (p `elem`) (map (bySuper ce) ps) ||
  case byInst ce p of
    Nothing -> False
    Just (qs, _) -> all (entail ce ps) qs

simplify :: ([Predicate Type Name] -> Predicate Type Name -> Bool) -> [Predicate Type Name] -> [Predicate Type Name]
simplify ent = loop []
  where
    loop rs [] = rs
    loop rs (p:ps)
      | ent (rs ++ ps) p = loop rs ps
      | otherwise = loop (p : rs) ps

reduce :: Show l => Map Name (Class Type Name l) -> [Predicate Type Name] -> [Predicate Type Name]
reduce ce = simplify (scEntail ce) . elimTauts ce

elimTauts :: Show l => Map Name (Class Type Name l) -> [Predicate Type Name] -> [Predicate Type Name]
elimTauts ce ps = [p | p <- ps, not (entail ce [] p)]

scEntail :: Map Name (Class Type Name l) -> [Predicate Type Name] -> Predicate Type Name -> Bool
scEntail ce ps p = any (p `elem`) (map (bySuper ce) ps)
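
-- Worked example, assuming a class environment in which Ord lists Eq as
-- its only superclass, Eq has no superclasses, and the only instance is
-- Eq Integer:
--
--   reduce ce [Eq Integer, Eq a, Ord a]   -- [Ord a]
--
-- elimTauts removes Eq Integer, which the instance entails with no
-- hypotheses, and simplify then drops Eq a because
-- bySuper ce (Ord a) = [Ord a, Eq a] already contains it.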

quantify :: [TypeVariable Name] -> Qualified Type Name (Type Name) -> Scheme Type Name Type
quantify vs qt = Forall vs' qt
  where
    vs' = [v | v <- getQualifiedTypeVariables qt, v `elem` vs]
    {-ks = map typeVariableKind vs'-}
    {-s = zipWith Substitution vs' (map undefined {-GenericType-} [0 ..])-}
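
-- Worked example, writing types in source syntax for brevity:
-- quantify [a] (Qualified [Show a] (a -> b)) gives
-- Forall [a] (Qualified [Show a] (a -> b)). Only variables that are both
-- in the list and in the qualified type are bound, in order of first
-- occurrence (predicates first), so b stays free.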

toScheme :: Type Name -> Scheme Type Name Type
toScheme t = Forall [] (Qualified [] t)

merge
  :: MonadThrow m
  => [Substitution Name] -> [Substitution Name] -> m [Substitution Name]
merge s1 s2 =
  if agree
    then return (s1 ++ s2)
    else throwM MergeFail
  where
    agree =
      all
        (\v -> substituteType s1 (VariableType v) == substituteType s2 (VariableType v))
        (map substitutionTypeVariable s1 `intersect`
         map substitutionTypeVariable s2)

inferExpressionType
  :: (MonadThrow m, Show l)
  => Map Name (Class Type Name l)
  -> [(TypeSignature Type Name Name)]
  -> (Expression Type Name l)
  -> InferT m ([Predicate Type Name], Type Name, Expression Type Name (TypeSignature Type Name l))
inferExpressionType ce as (ParensExpression l e) = do
  (ps, t, e') <- inferExpressionType ce as e
  pure (ps, t, ParensExpression (fmap (const l) (expressionLabel e')) e')
inferExpressionType _ as (VariableExpression l i) = do
  sc <- lookupName i as
  qualified@(Qualified ps t) <- freshInst sc
  let scheme = (Forall [] qualified)
  return (ps, t, VariableExpression (TypeSignature l scheme) i)
inferExpressionType _ _ (ConstantExpression l i) = do
  t <- newVariableType StarKind
  return ([], t, (ConstantExpression (TypeSignature l (toScheme t)) i))
inferExpressionType _ as (ConstructorExpression l i) = do
  sc <- lookupName i as
  qualified@(Qualified ps t) <- freshInst sc
  let scheme = (Forall [] qualified)
  return (ps, t, ConstructorExpression (TypeSignature l scheme) i)
inferExpressionType _ _ (LiteralExpression l0 l) = do
  specialTypes <- InferT (gets inferStateSpecialTypes)
  (ps, t) <- inferLiteralType specialTypes l
  let scheme = (Forall [] (Qualified ps t))
  return (ps, t, LiteralExpression (TypeSignature l0 scheme) l)
inferExpressionType ce as (ApplicationExpression l e f) = do
  (ps, te, e') <- inferExpressionType ce as e
  (qs, tf, f') <- inferExpressionType ce as f
  t <- newVariableType StarKind
  specialTypes <- InferT (gets inferStateSpecialTypes)
  let makeArrow :: Type Name -> Type  Name -> Type  Name
      a `makeArrow` b = ApplicationType (ApplicationType (ConstructorType(specialTypesFunction specialTypes)) a) b
  unify (tf `makeArrow` t) te
  let scheme = (Forall [] (Qualified (ps++qs) t))
  return (ps ++ qs, t, ApplicationExpression (TypeSignature l scheme) e' f')
inferExpressionType ce as (InfixExpression l x (i,op) y) = do
  (ps, ts, ~(ApplicationExpression l' (ApplicationExpression _ (op') x') y')) <-
    inferExpressionType
      ce
      as
      (ApplicationExpression l (ApplicationExpression l op x) y)
  pure (ps, ts, InfixExpression l' x' (i, op') y')
inferExpressionType ce as (LetExpression l bg e) = do
  (ps, as', bg') <- inferBindGroupTypes ce as bg
  (qs, t, e') <- inferExpressionType ce (as' ++ as) e
  let scheme = (Forall [] (Qualified (ps++qs) t))
  return (ps ++ qs, t, LetExpression (TypeSignature l scheme) bg' e')
inferExpressionType ce as (LambdaExpression l alt) = do
  (x, y, s) <- inferAltTypeForLambda ce as alt
  pure
    ( x
    , y
    , LambdaExpression
        (TypeSignature l (typeSignatureScheme (alternativeLabel s)))
        s)
inferExpressionType ce as (IfExpression l e e1 e2) = do
  (ps, t, e') <- inferExpressionType ce as e
  specialTypes <- InferT (gets inferStateSpecialTypes)
  unify t (dataTypeConstructor (specialTypesBool specialTypes))
  (ps1, t1, e1') <- inferExpressionType ce as e1
  (ps2, t2, e2') <- inferExpressionType ce as e2
  unify t1 t2
  let scheme = (Forall [] (Qualified (ps ++ ps1 ++ ps2) t1))
  return (ps ++ ps1 ++ ps2, t1, IfExpression (TypeSignature l scheme) e' e1' e2')
inferExpressionType ce as (CaseExpression l e branches) = do
  (ps0, t, e') <- inferExpressionType ce as e
  v <- newVariableType StarKind
  let tiBr (CaseAlt l' pat f) = do
        (pat', ps, as', t') <- inferPattern as pat
        unify t t'
        (qs, t'', f') <- inferExpressionType ce (as' ++ as) f
        unify v t''
        return
          (ps ++ qs, (CaseAlt (fmap (const l') (expressionLabel f')) pat' f'))
  branchs <- mapM tiBr branches
  let pss = map fst branchs
      branches' = map snd branchs
  let scheme = (Forall [] (Qualified (ps0 ++ concat pss) v))
  return
    (ps0 ++ concat pss, v, CaseExpression (TypeSignature l scheme) e' branches')

inferAltTypeForLambda
  :: (MonadThrow m, Show l)
  => Map Name (Class Type Name l)
  -> [(TypeSignature Type Name Name)]
  -> Alternative Type Name l
  -> InferT m ([Predicate Type Name], Type Name, Alternative Type Name (TypeSignature Type Name l))
inferAltTypeForLambda ce as alt =
  inferAltType0
    ce
    as
    (\l scheme pats ex -> Alternative (TypeSignature l scheme) pats ex)
    alt

inferAltTypeForBind
  :: (MonadThrow m, Show l)
  => Map Name (Class Type Name l)
  -> [(TypeSignature Type Name Name)]
  -> Alternative Type Name l
  -> InferT m ([Predicate Type Name], Type Name, Alternative Type Name (TypeSignature Type Name l))
inferAltTypeForBind ce as alt =
  inferAltType0 ce as makeAltForDecl alt

inferAltType0
  :: (Show t1, MonadThrow m)
  => Map Name (Class Type Name t1)
  -> [TypeSignature Type Name Name]
  -> (t1 -> Scheme Type Name Type -> [Pattern Type Name (TypeSignature Type Name t1)] -> Expression Type Name (TypeSignature Type Name t1) -> t)
  -> Alternative Type Name t1
  -> InferT m ([Predicate Type Name], Type Name, t)
inferAltType0 ce as makeAlt (Alternative l pats e) = do
  (pats', ps, as', ts) <- inferPatterns as pats
  (qs, t, e') <- inferExpressionType ce (as' ++ as) e
  specialTypes <- InferT (gets inferStateSpecialTypes)
  let makeArrow :: Type Name -> Type Name -> Type Name
      a `makeArrow` b = ApplicationType (ApplicationType (ConstructorType(specialTypesFunction specialTypes)) a) b
  let scheme = (Forall [] (Qualified (ps ++ qs) (foldr makeArrow t ts)))
  return (ps ++ qs, foldr makeArrow t ts, makeAlt l scheme pats' e')

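-- | During parsing,
--
-- > f = \x -> x
--
-- is parsed as
--
-- > f x = x
--
-- After type-checking, the lambda is expanded out again, giving
--
-- > f = \x -> x
--
-- but now type-checked and generalized: both the alternative and the
-- wrapping lambda carry the inferred scheme.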
makeAltForDecl
  :: a
  -> Scheme Type i1 Type
  -> [Pattern Type i (TypeSignature Type i1 a)]
  -> Expression Type i (TypeSignature Type i1 a)
  -> Alternative Type i (TypeSignature Type i1 a)
makeAltForDecl l scheme pats' e' =
  if null pats'
    then Alternative (TypeSignature l scheme) pats' e'
    else Alternative
           (TypeSignature l scheme)
           []
           (LambdaExpression
              (TypeSignature l scheme)
              (Alternative (TypeSignature l scheme) pats' e'))

inferAltTypes
  :: (MonadThrow m, Show l)
  => Map Name (Class Type Name l)
  -> [(TypeSignature Type Name Name)]
  -> [Alternative Type Name l]
  -> Type Name
  -> InferT m ([Predicate Type Name], [Alternative Type Name (TypeSignature Type Name l)])
inferAltTypes ce as alts t = do
  psts <- mapM (inferAltTypeForBind ce as) alts
  mapM_ (unify t) (map snd3 psts)
  return (concat (map fst3 psts), map thd3 psts)
  where snd3 (_,x,_) = x
        thd3 (_,_,x) = x
        fst3 (x,_,_) = x

split
  :: (MonadThrow m, Show l)
  => Map Name (Class Type Name l) -> [TypeVariable Name] -> [TypeVariable Name] -> [Predicate Type Name] -> m ([Predicate Type Name], [Predicate Type Name])
split ce fs gs ps = do
  let ps' = reduce ce ps
      (ds, rs) = partition (all (`elem` fs) . getPredicateTypeVariables) ps'
  rs' <- defaultedPredicates ce (fs ++ gs) rs
  return (ds, rs \\ rs')

candidates :: (Show l)=> Map Name (Class Type Name l) -> Ambiguity Name -> [Type Name]
candidates ce (Ambiguity v qs) =
  [ t'
  | let is = [i | IsIn i _ <- qs]
        ts = [t | IsIn _ t <- qs]
  , all ([VariableType v] ==) ts
  , any (`elem` numClasses) is
  , all (`elem` stdClasses) is
  , t' <- [VariableType (TypeVariable (TypeName (-1) "x") StarKind)]-- classEnvironmentDefaults ce
  , all (entail ce []) [IsIn i [t'] | i <- is]
  ]
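  where
    -- Defaulting is disabled: both lists hold only a sentinel name that
    -- is not expected to match any class, so no candidate is produced.
    numClasses = [ForallName (-1)]
    stdClasses = [ForallName (-1)]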


withDefaults
  :: (MonadThrow m, Show l)
  => String
  -> ([Ambiguity Name] -> [Type Name] -> a)
  -> Map Name (Class Type Name l)
  -> [TypeVariable Name]
  -> [Predicate Type Name]
  -> m a
withDefaults _label f ce vs ps
  | any null tss = throwM (AmbiguousInstance vps)
  | otherwise = do
    return (f vps (map head tss))
  where
    -- showp :: Show a => a -> String
    -- showp = \x -> "(" ++ show x ++ ")"
    vps = ambiguities vs ps
    tss = map (candidates ce) vps

defaultedPredicates
  :: (MonadThrow m, Show l)
  => Map Name (Class Type Name l) -> [TypeVariable Name] -> [Predicate Type Name] -> m [Predicate Type Name]
defaultedPredicates = withDefaults "defaultedPredicates" (\vps _ -> concat (map ambiguityPredicates vps))

defaultSubst
  :: (MonadThrow m, Show l)
  => Map Name (Class Type Name l) -> [TypeVariable Name] -> [Predicate Type Name] -> m [Substitution Name]
defaultSubst = withDefaults "defaultSubst" (\vps ts -> zipWith Substitution (map ambiguityTypeVariable vps) ts)

-- extSubst
--   :: Monad m
--   => [Substitution] -> InferT m ()
-- extSubst s' =
--   InferT
--     (modify
--        (\s -> s {inferStateSubstitutions = s' @@ inferStateSubstitutions s}))

freshInst
  :: Monad m
  => Scheme Type Name Type -> InferT m (Qualified Type Name (Type Name))
freshInst (Forall ks qt) = do
  ts <- mapM (\vorig -> (vorig, ) <$> newVariableType (typeVariableKind vorig)) ks
  return (instantiateQualified ts qt)


================================================
FILE: src/Duet/Parser.hs
================================================
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE FlexibleContexts #-}
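-- | Parser for Duet source code. Text is first split into tokens by
-- "Duet.Tokenizer", then the token stream is parsed into declarations.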

module Duet.Parser where

import           Control.Monad
import           Control.Monad.Catch
import           Control.Monad.IO.Class
import           Data.List
import qualified Data.Map.Strict as M
import           Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.IO as T
import           Duet.Printer
import           Duet.Tokenizer
import           Duet.Types
import           Text.Parsec hiding (satisfy, anyToken)

parseFile :: (MonadIO m, MonadThrow m) => FilePath -> m [Decl UnkindedType Identifier Location]
parseFile fp = do
  t <- liftIO (T.readFile fp)
  parseText fp t

parseText :: MonadThrow m => SourceName -> Text -> m [Decl UnkindedType Identifier Location]
parseText fp inp =
  case parse tokensTokenizer fp (inp) of
    Left e -> throwM (TokenizerError e)
    Right tokens' ->
      case runParser tokensParser 0 fp tokens' of
        Left e -> throwM (ParserError e)
        Right ast -> pure ast

parseTextWith
  :: (Num u, MonadThrow m)
  => Parsec [(Token, Location)] u a -> SourceName -> Text -> m a
parseTextWith p fp inp =
  case parse tokensTokenizer fp (inp) of
    Left e -> throwM (TokenizerError e)
    Right tokens' ->
      case runParser p 0 fp tokens' of
        Left e -> throwM (ParserError e)
        Right ast -> pure ast

parseType' :: Num u => SourceName -> Parsec [(Token, Location)] u b -> Text -> Either ParseError b
parseType' fp p inp =
  case parse tokensTokenizer fp (inp) of
    Left e -> Left e
    Right tokens' ->
      case runParser p 0 fp tokens' of
        Left e -> Left e
        Right ast -> Right ast

tokensParser :: TokenParser [Decl UnkindedType Identifier Location]
tokensParser = moduleParser <* endOfTokens

moduleParser :: TokenParser [Decl UnkindedType Identifier Location]
moduleParser =
  many
    (varfundeclExplicit <|> fmap (uncurry DataDecl) datadecl <|>
     fmap (uncurry ClassDecl) classdecl <|>
     fmap (uncurry InstanceDecl) instancedecl)

classdecl :: TokenParser (Location, Class UnkindedType Identifier Location)
classdecl =
  go <?> "class declaration (e.g. class Show a where show :: a -> String)"
  where
    go = do
      u <- getState
      loc <- equalToken ClassToken
      setState (locationStartColumn loc)
      (c, _) <-
        consumeToken
          (\case
             Constructor c -> Just c
             _ -> Nothing) <?>
        "new class name e.g. Show"
      vars <- many1 kindableTypeVariable
      mwhere <-
        fmap (const True) (equalToken Where) <|> fmap (const False) endOfDecl
      methods <-
        if mwhere
          then do
            (_, identLoc) <-
              lookAhead
                (consumeToken
                   (\case
                      Variable i -> Just i
                      _ -> Nothing)) <?>
              "class methods e.g. foo :: a -> Int"
            (many1 (methodParser (locationStartColumn identLoc))) <* endOfDecl
          else (pure [])
      setState u
      _ <- (pure () <* satisfyToken (== NonIndentedNewline)) <|> endOfTokens
      pure
        ( loc
        , Class
          { className = Identifier (T.unpack c)
          , classTypeVariables = vars
          , classSuperclasses = []
          , classInstances = []
          , classMethods = M.fromList methods
          })
      where
        endOfDecl =
          (pure () <* satisfyToken (== NonIndentedNewline)) <|> endOfTokens
        methodParser startCol = go' <?> "method signature e.g. foo :: a -> Y"
          where
            go' = do
              u <- getState
              (v, p) <-
                consumeToken
                  (\case
                     Variable i -> Just i
                     _ -> Nothing)
              when
                (locationStartColumn p /= startCol)
                (unexpected
                   ("method name at column " ++
                    show (locationStartColumn p) ++
                    ", it should start at column " ++
                    show startCol ++ " to match the others"))
              setState startCol
              _ <- equalToken Colons <?> "‘::’ for method signature"
              scheme <- parseScheme <?> "method type signature e.g. foo :: Int"
              setState u
              pure (Identifier (T.unpack v), scheme)

kindableTypeVariable :: Stream s m (Token, Location) => ParsecT s Int m (TypeVariable Identifier)
kindableTypeVariable = (unkinded <|> kinded) <?> "type variable (e.g. ‘a’, ‘f’, etc.)"
  where
    kinded =
      kparens
        (do t <- unkinded
            _ <- equalToken Colons
            k <- kindParser
            pure (TypeVariable (typeVariableIdentifier t) k))
      where
        kparens :: TokenParser a -> TokenParser a
        kparens p = g <?> "parens e.g. (x)"
          where
            g = do
              _ <- equalToken OpenParen
              e <-
                p <?> "type with kind inside parentheses e.g. (t :: Type)"
              _ <- equalToken CloseParen <?> "closing parenthesis ‘)’"
              pure e
    unkinded = do
      (v, _) <-
        consumeToken
          (\case
             Variable i -> Just i
             _ -> Nothing) <?>
        "variable name"
      pure (TypeVariable (Identifier (T.unpack v)) StarKind)
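
-- Example (illustrative, read off the definitions above): a bare variable
-- defaults to kind Type, while a parenthesised annotation supplies the kind
-- explicitly via 'kindParser'.
--
--   a                    ~>  TypeVariable (Identifier "a") StarKind
--   (f :: Type -> Type)  ~>  TypeVariable (Identifier "f")
--                              (FunctionKind StarKind StarKind)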

parseScheme
  :: Stream s m (Token, Location)
  => ParsecT s Int m (Scheme UnkindedType Identifier UnkindedType)
parseScheme = do
  explicit <-
    fmap (const True) (lookAhead (equalToken ForallToken)) <|> pure False
  if explicit
    then quantified
    else do
      ty@(Qualified _ qt) <- parseQualified
      pure (Forall (nub (collectTypeVariables qt)) ty)
  where
    quantified = do
      _ <- equalToken ForallToken
      vars <- many1 kindableTypeVariable <?> "type variables"
      _ <- equalToken Period
      ty <- parseQualified
      pure (Forall vars ty)
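
-- Example (illustrative): without a leading forall, the type variables of
-- the head type are collected (kind Type) and de-duplicated with 'nub':
--
--   a -> a                 ~>  Forall [a] (Qualified [] (a -> a))
--   Show a => a -> String  ~>  Forall [a] (Qualified [Show a] (a -> String))
--
-- With a leading forall, the variables listed before the period are used as
-- written, each optionally kind-annotated as accepted by
-- 'kindableTypeVariable'. In the implicit case only the head type is
-- scanned, so a variable that appears solely in the context is not
-- quantified.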

parseSchemePredicate
  :: Stream s m (Token, Location)
  => ParsecT s Int m (Scheme UnkindedType Identifier (Predicate UnkindedType))
parseSchemePredicate = do
  explicit <-
    fmap (const True) (lookAhead (equalToken ForallToken)) <|> pure False
  if explicit
    then quantified
    else do
      ty@(Qualified _ (IsIn _ qt)) <- parseQualifiedPredicate
      pure (Forall (nub (concatMap collectTypeVariables qt)) ty)
  where
    quantified = do
      _ <- equalToken ForallToken
      vars <- many1 kindableTypeVariable <?> "type variables"
      _ <- equalToken Period
      ty <- parseQualifiedPredicate
      pure (Forall vars ty)

parseQualified
  :: Stream s m (Token, Location)
  => ParsecT s Int m (Qualified UnkindedType Identifier (UnkindedType Identifier))
parseQualified = do
  ty <- parsedTypeLike
  (case ty of
     ParsedQualified ps x -> Qualified <$> mapM toUnkindedPred ps <*> toType x
       where toUnkindedPred (IsIn c ts) = IsIn c <$> mapM toType ts
     _ -> do
       t <- toType ty
       pure (Qualified [] t)) <?>
    "qualified type e.g. Show x => x"

parseQualifiedPredicate
  :: Stream s m (Token, Location)
  => ParsecT s Int m (Qualified UnkindedType Identifier (Predicate UnkindedType Identifier))
parseQualifiedPredicate = do
  ty <- parsedTypeLike
  (case ty of
     ParsedQualified ps x -> Qualified <$> mapM toUnkindedPred ps <*> toPredicateUnkinded x
       where toUnkindedPred (IsIn c ts) = IsIn c <$> mapM toType ts
     _ -> do
       t <- toPredicateUnkinded ty
       pure (Qualified [] t)) <?>
    "qualified type e.g. Show x => x"

collectTypeVariables :: UnkindedType i -> [TypeVariable i]
collectTypeVariables =
  \case
     UnkindedTypeConstructor {} -> []
     UnkindedTypeVariable i -> [TypeVariable i StarKind]
     UnkindedTypeApp f x -> collectTypeVariables f ++ collectTypeVariables x
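
-- Example (illustrative): the type a -> Maybe a is represented as
--
--   UnkindedTypeApp
--     (UnkindedTypeApp (UnkindedTypeConstructor (Identifier "(->)"))
--                      (UnkindedTypeVariable (Identifier "a")))
--     (UnkindedTypeApp (UnkindedTypeConstructor (Identifier "Maybe"))
--                      (UnkindedTypeVariable (Identifier "a")))
--
-- for which 'collectTypeVariables' returns the variable a twice (each with
-- StarKind); callers such as 'parseScheme' apply 'nub' to remove duplicates.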

instancedecl :: TokenParser (Location, Instance UnkindedType Identifier Location)
instancedecl =
  go <?> "instance declaration (e.g. instance Show Int where show = ...)"
  where
    go = do
      u <- getState
      loc <- equalToken InstanceToken
      setState (locationStartColumn loc)
      predicate@(Forall _ (Qualified _ (IsIn (Identifier c) _))) <-
        parseSchemePredicate
      mwhere <-
        fmap (const True) (equalToken Where) <|> fmap (const False) endOfDecl
      methods <-
        if mwhere
          then do
            (_, identLoc) <-
              lookAhead
                (consumeToken
                   (\case
                      Variable i -> Just i
                      _ -> Nothing)) <?>
              "instance methods e.g. foo = \\x -> x"
            (many1 (methodParser (locationStartColumn identLoc))) <* endOfDecl
          else (pure [])
      setState u
      _ <- (pure () <* satisfyToken (== NonIndentedNewline)) <|> endOfTokens
      let dictName = "$dict" ++ c
      pure
        ( loc
        , Instance
          { instancePredicate = predicate
          , instanceDictionary =
              Dictionary (Identifier dictName) (M.fromList methods)
          })
      where
        endOfDecl =
          (pure () <* satisfyToken (== NonIndentedNewline)) <|> endOfTokens
        methodParser startCol =
          go' <?> "method implementation e.g. foo = \\x -> f x"
          where
            go' = do
              u <- getState
              (v, p) <-
                consumeToken
                  (\case
                     Variable i -> Just i
                     _ -> Nothing)
              when
                (locationStartColumn p /= startCol)
                (unexpected
                   ("method name at column " ++
                    show (locationStartColumn p) ++
                    ", it should start at column " ++
                    show startCol ++ " to match the others"))
              setState startCol
              _ <- equalToken Equals <?> "‘=’ for method declaration e.g. x = 1"
              e <- expParser
              setState u
              pure (Identifier (T.unpack v), (p, makeAlt (expressionLabel e) e))
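
-- Example (illustrative): the declaration
--
--   instance Show Int where
--     show = \x -> "Int"
--
-- yields an Instance whose predicate is
-- Forall [] (Qualified [] (IsIn Show [Int])) and whose dictionary is named
-- "$dictShow", mapping show to its location and the lambda's Alternative
-- (with parameter x). The dictionary name is built from the class name
-- only, so on its own it does not distinguish two instances of one class.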

parseType :: Stream s m (Token, Location) => ParsecT s Int m (UnkindedType Identifier)
parseType = do
  x <- parsedTypeLike
  toType x

toPredicateUnkinded :: Stream s m t => ParsedType i -> ParsecT s u m (Predicate UnkindedType i)
toPredicateUnkinded = toPredicate >=> go
  where go (IsIn c tys) = IsIn c <$> mapM toType tys

toType :: Stream s m t => ParsedType i -> ParsecT s u m (UnkindedType i)
toType = go
  where
    go =
      \case
        ParsedTypeConstructor i -> pure (UnkindedTypeConstructor i)
        ParsedTypeVariable i -> pure (UnkindedTypeVariable i)
        ParsedTypeApp t1 t2 -> UnkindedTypeApp <$> go t1 <*> go t2
        ParsedQualified {} -> unexpected "qualification context"
        ParsedTuple {} -> unexpected "tuple"

datadecl :: TokenParser (Location, DataType UnkindedType Identifier)
datadecl = go <?> "data declaration (e.g. data Maybe a = Just a | Nothing)"
  where
    go = do
      loc <- equalToken Data
      (v, _) <-
        consumeToken
          (\case
             Constructor i -> Just i
             _ -> Nothing) <?>
        "new type name (e.g. Foo)"
      vs <- many kindableTypeVariable
      _ <- equalToken Equals
      cs <- sepBy1 consp (equalToken Bar)
      _ <- (pure () <* satisfyToken (== NonIndentedNewline)) <|> endOfTokens
      pure (loc, DataType (Identifier (T.unpack v)) vs cs)

kindParser :: Stream s m (Token, Location) => ParsecT s Int m Kind
kindParser = infix'
  where
    infix' = do
      left <- star
      tok <-
        fmap Just (operator <?> ("arrow " ++ curlyQuotes "->")) <|> pure Nothing
      case tok of
        Just (RightArrow, _) -> do
          right <-
            kindParser <?>
            ("right-hand side of function arrow " ++ curlyQuotes "->")
          pure (FunctionKind left right)
        _ -> pure left
      where
        operator =
          satisfyToken
            (\case
               RightArrow {} -> True
               _ -> False)
    star = do
      (c, _) <-
        consumeToken
          (\case
             Constructor c
               | c == "Type" -> Just StarKind
             _ -> Nothing)
      pure c
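
-- Example (illustrative): kinds are built from the constructor Type and the
-- arrow, which associates to the right:
--
--   Type                  ~>  StarKind
--   Type -> Type -> Type  ~>  FunctionKind StarKind (FunctionKind StarKind StarKind)
--
-- A parenthesised kind on the left of an arrow, such as
-- (Type -> Type) -> Type, is not accepted, because 'star' only matches the
-- Type constructor.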

consp :: TokenParser (DataTypeConstructor UnkindedType Identifier)
consp = do c <- consParser
           slots <- many slot
           pure (DataTypeConstructor c slots)
  where consParser = go <?> "value constructor (e.g. Just)"
          where
            go = do
              (c, _) <-
                consumeToken
                  (\case
                     Constructor c -> Just c
                     _ -> Nothing)
              pure
                (Identifier (T.unpack c))

slot :: TokenParser (UnkindedType Identifier)
slot = consParser <|> variableParser <|> parens parseType
  where
    variableParser = go <?> "type variable (e.g. ‘a’, ‘s’, etc.)"
      where
        go = do
          (v, _) <-
            consumeToken
              (\case
                 Variable i -> Just i
                 _ -> Nothing)
          pure (UnkindedTypeVariable (Identifier (T.unpack v)))
    consParser = go <?> "type constructor (e.g. Maybe)"
      where
        go = do
          (c, _) <-
            consumeToken
              (\case
                 Constructor c -> Just c
                 _ -> Nothing)
          pure (UnkindedTypeConstructor (Identifier (T.unpack c)))

data ParsedType i
  = ParsedTypeConstructor i
  | ParsedTypeVariable i
  | ParsedTypeApp (ParsedType i) (ParsedType i)
  | ParsedQualified [Predicate ParsedType i] (ParsedType i)
  | ParsedTuple [ParsedType i]
  deriving (Show)

parsedTypeLike :: TokenParser (ParsedType Identifier)
parsedTypeLike = infix' <|> app <|> unambiguous
  where
    infix' = do
      left <- (app <|> unambiguous) <?> "left-hand side of function arrow"
      tok <-
        fmap Just (operator <?> ("function arrow " ++ curlyQuotes "->")) <|>
        fmap Just (operator2 <?> ("constraint arrow " ++ curlyQuotes "=>")) <|>
        pure Nothing
      case tok of
        Just (RightArrow, _) -> do
          right <-
            parsedTypeLike <?>
            ("right-hand side of function arrow " ++ curlyQuotes "->")
          pure
            (ParsedTypeApp
               (ParsedTypeApp (ParsedTypeConstructor (Identifier "(->)")) left)
               right)
        Just (Imply, _) -> do
          left' <- parsedTypeToPredicates left <?> "constraints e.g. Show a or (Read a, Show a)"
          right <-
            parsedTypeLike <?>
            ("right-hand side of constraints " ++ curlyQuotes "=>")
          pure (ParsedQualified left' right)
        _ -> pure left
      where
        operator =
          satisfyToken
            (\case
               RightArrow {} -> True
               _ -> False)
        operator2 =
          satisfyToken
            (\case
               Imply {} -> True
               _ -> False)
    app = do
      f <- unambiguous
      args <- many unambiguous
      pure (foldl' ParsedTypeApp f args)
    unambiguous =
      atomicType <|>
      parensTy
        (do xs <- sepBy1 parsedTypeLike (equalToken Comma)
            case xs of
              [x] -> pure x
              _ -> pure (ParsedTuple xs))
    atomicType = consParse <|> varParse
    consParse = do
      (v, _) <-
        consumeToken
          (\case
             Constructor i -> Just i
             _ -> Nothing) <?>
        "type constructor (e.g. Int, Maybe)"
      pure (ParsedTypeConstructor (Identifier (T.unpack v)))
    varParse = do
      (v, _) <-
        consumeToken
          (\case
             Variable i -> Just i
             _ -> Nothing) <?>
        "type variable (e.g. a, f)"
      pure (ParsedTypeVariable (Identifier (T.unpack v)))
    parensTy p = go <?> "parentheses e.g. (T a)"
      where
        go = do
          _ <- equalToken OpenParen
          e <- p <?> "type inside parentheses e.g. (Maybe a)"
          _ <- equalToken CloseParen <?> "closing parenthesis ‘)’"
          pure e
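
-- Example (illustrative): the function arrow is represented as an
-- application of the constructor "(->)" and associates to the right; a
-- parenthesised, comma-separated context before => becomes a list of
-- predicates:
--
--   a -> b -> c
--     ~>  ParsedTypeApp (ParsedTypeApp (->) a)
--                       (ParsedTypeApp (ParsedTypeApp (->) b) c)
--   (Show a, Eq a) => a -> String
--     ~>  ParsedQualified [IsIn Show [a], IsIn Eq [a]] (a -> String)
--
-- Any other parenthesised list of two or more types yields ParsedTuple,
-- which 'toType' rejects.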

parsedTypeToPredicates :: Stream s m t => ParsedType i -> ParsecT s u m [Predicate ParsedType i]
parsedTypeToPredicates =
  \case
    ParsedTuple xs -> mapM toPredicate xs
    x -> fmap return (toPredicate x)

toPredicate :: Stream s m t => ParsedType i -> ParsecT s u m (Predicate ParsedType i)
toPredicate t =
  case targs t of
    (ParsedTypeConstructor i, vars@(_:_)) -> do
      pure (IsIn i vars)
    _ -> unexpected "non-class constraint"

toVar :: Stream s m t1 => ParsedType t -> ParsecT s u m (ParsedType t)
toVar =
  \case
    v@ParsedTypeVariable {} -> pure v
    _ -> unexpected "non-type-variable"

targs :: ParsedType t -> (ParsedType t, [ParsedType t])
targs e = go e []
  where
    go (ParsedTypeApp f x) args = go f (x : args)
    go f args = (f, args)
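
-- Example (illustrative): 'targs' flattens a left-nested application into
-- its head and its arguments, which is what 'toPredicate' relies on:
--
--   targs (Either a b)  ==  (Either, [a, b])
--
-- so the constraint Either a b becomes IsIn Either [a, b], while a bare
-- variable such as a (no constructor head, no arguments) is rejected as a
-- non-class constraint.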

varfundecl :: TokenParser (ImplicitlyTypedBinding UnkindedType Identifier Location)
varfundecl = go <?> "variable declaration (e.g. x = 1, f = \\x -> x * x)"
  where
    go = do
      (v, loc) <-
        consumeToken
          (\case
             Variable i -> Just i
             _ -> Nothing) <?>
        "variable name"
      _ <- equalToken Equals <?> "‘=’ for variable declaration e.g. x = 1"
      e <- expParser
      _ <- (pure () <* satisfyToken (== NonIndentedNewline)) <|> endOfTokens
      pure (ImplicitlyTypedBinding loc (Identifier (T.unpack v), loc) [makeAlt loc e])

varfundeclExplicit :: TokenParser (Decl UnkindedType Identifier Location)
varfundeclExplicit =
  go <?> "explicitly typed variable declaration (e.g. x :: Int and x = 1)"
  where
    go = do
      (v0, loc) <-
        consumeToken
          (\case
             Variable i -> Just i
             _ -> Nothing) <?>
        "variable name"
      (tok, _) <- anyToken <?> curlyQuotes "::" ++ " or " ++ curlyQuotes "="
      case tok of
        Colons -> do
          scheme <- parseScheme <?> "type signature e.g. foo :: Int"
          _ <- (pure () <* satisfyToken (== NonIndentedNewline)) <|> endOfTokens
          (v, _) <-
            consumeToken
              (\case
                 Variable i -> Just i
                 _ -> Nothing) <?>
            "variable name"
          when
            (v /= v0)
            (unexpected "variable binding name different from the type signature")
          _ <- equalToken Equals <?> "‘=’ for variable declaration e.g. x = 1"
          e <- expParser
          _ <- (pure () <* satisfyToken (== NonIndentedNewline)) <|> endOfTokens
          pure
            (BindDecl
               loc
               (ExplicitBinding
                  (ExplicitlyTypedBinding loc
                     (Identifier (T.unpack v), loc)
                     scheme
                     [makeAlt loc e])))
        Equals -> do
          e <- expParser
          _ <- (pure () <* satisfyToken (== NonIndentedNewline)) <|> endOfTokens
          pure
            (BindDecl
               loc
               (ImplicitBinding
                  (ImplicitlyTypedBinding
                     loc
                     (Identifier (T.unpack v0), loc)
                     [makeAlt loc e])))
        t -> unexpected (tokenStr t)


makeAlt :: l -> Expression t i l -> Alternative t i l
makeAlt loc e =
  case e of
    LambdaExpression _ alt -> alt
    _ -> Alternative loc [] e
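
-- Example (illustrative): a top-level lambda is unwrapped so that its
-- parameters become the alternative's patterns (keeping the lambda's own
-- label); any other expression becomes a zero-parameter alternative
-- labelled with the given location:
--
--   f = \x y -> x   ~>  Alternative <lambda label> [x, y] x
--   n = 1           ~>  Alternative loc [] 1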

case' :: TokenParser (Expression UnkindedType Identifier Location)
case' = do
  u <- getState
  loc <- equalToken Case
  setState (locationStartColumn loc)
  e <- expParser <?> "expression to do case analysis e.g. case e of ..."
  _ <- equalToken Of
  p <- lookAhead altPat <?> "case pattern"
  alts <- many (altParser (Just e) (locationStartColumn (patternLabel p)))
  setState u
  pure (CaseExpression loc e alts)
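
-- Example (illustrative): the column of the first pattern (found with
-- lookAhead) fixes the column at which every alternative must start:
--
--   case m of
--     Just y -> y
--     Nothing -> 0
--
-- If a pattern is consumed at a different column, 'altParser' fails with a
-- "pattern at column ..." message.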

altsParser
  :: Stream s m (Token, Location)
  => ParsecT s Int m [(CaseAlt UnkindedType Identifier Location)]
altsParser = many (altParser Nothing 1)

altParser
  :: Maybe (Expression UnkindedType Identifier Location)
  -> Int
  -> TokenParser (CaseAlt UnkindedType Identifier Location)
altParser e' startCol =
  (do u <- getState
      p <- altPat
      when
        (locationStartColumn (patternLabel p) /= startCol)
        (unexpected
           ("pattern at column " ++
            show (locationStartColumn (patternLabel p)) ++
            ", it should start at column " ++
            show startCol ++ " to match the others"))
      setState startCol
      _ <- equalToken RightArrow
      e <- expParser
      setState u
      pure (CaseAlt (Location 0 0 0 0) p e)) <?>
  ("case alternative" ++
   (case e' of
      Just scrutinee ->
        " e.g.\n\ncase " ++
        printExpression defaultPrint scrutinee ++
        " of\n  Just bar -> bar"
      Nothing -> ""))

altPat :: TokenParser (Pattern UnkindedType Identifier Location)
altPat = bang <|> varp <|> intliteral <|> consParser <|> stringlit
  where
    bang =
      (BangPattern <$>
       (consumeToken
          (\case
             Bang -> Just Bang
             _ -> Nothing) *>
        patInner)) <?> "bang pattern"
    patInner = parenpat <|> varp <|> intliteral <|> unaryConstructor
    parenpat = go
      where
        go = do
          _ <- equalToken OpenParen
          e <- varp <|> altPat
          _ <- equalToken CloseParen <?> "closing parenthesis ‘)’"
          pure e
    intliteral = go <?> "integer (e.g. 42, 123)"
      where
        go = do
          (c, loc) <-
            consumeToken
              (\case
                 Integer c -> Just c
                 _ -> Nothing)
          pure (LiteralPattern loc (IntegerLiteral c))
    stringlit = go <?> "string (e.g. \"foo\")"
      where
        go = do
          (c, loc) <-
            consumeToken
              (\case
                 String c -> Just c
                 _ -> Nothing)
          pure (LiteralPattern loc (StringLiteral (T.unpack c)))
    varp = go <?> "variable pattern (e.g. x)"
      where
        go = do
          (v, loc) <-
            consumeToken
              (\case
                 Variable i -> Just i
                 _ -> Nothing)
          pure
            (if T.isPrefixOf "_" v
               then WildcardPattern loc (T.unpack v)
               else VariablePattern loc (Identifier (T.unpack v)))
    unaryConstructor = go <?> "unary constructor (e.g. Nothing)"
      where
        go = do
          (c, loc) <-
            consumeToken
              (\case
                 Constructor c -> Just c
                 _ -> Nothing)
          pure (ConstructorPattern loc (Identifier (T.unpack c)) [])
    consParser = go <?> "constructor pattern (e.g. Just x)"
      where
        go = do
          (c, loc) <-
            consumeToken
              (\case
                 Constructor c -> Just c
                 _ -> Nothing)
          args <- many patInner
          pure (ConstructorPattern loc (Identifier (T.unpack c)) args)
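
-- Example (illustrative, assuming the tokenizer emits names such as _rest
-- as Variable tokens, as 'varParser' also expects):
--
--   x                ~>  VariablePattern x
--   _rest            ~>  WildcardPattern "_rest"
--   42               ~>  LiteralPattern (IntegerLiteral 42)
--   "hi"             ~>  LiteralPattern (StringLiteral "hi")
--   Just (Pair a b)  ~>  ConstructorPattern Just [ConstructorPattern Pair [a, b]]
--   !n               ~>  BangPattern (VariablePattern n)
--
-- Constructor arguments are parsed with 'patInner', so an argument that is
-- a string literal, or a constructor applied to arguments, must be
-- parenthesised.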

expParser :: TokenParser (Expression UnkindedType Identifier Location)
expParser = case' <|> lambda <|> ifParser <|> infix' <|> app <|> atomic
  where
    app = do
      left <- funcOp <?> "function expression"
      right <- many unambiguous <?> "function arguments"
      case right of
        [] -> pure left
        _ -> pure (foldl (ApplicationExpression (Location 0 0 0 0)) left right)
    infix' =
      (do left <- (app <|> unambiguous) <?> "left-hand side of operator"
          tok <- fmap Just (operator <?> "infix operator") <|> pure Nothing
          case tok of
            Just (Operator t, _) -> do
              right <-
                (app <|> unambiguous) <?>
                ("right-hand side of " ++
                 curlyQuotes (T.unpack t) ++ " operator")
              badop <- fmap Just (lookAhead operator) <|> pure Nothing
              let infixexp =
                    InfixExpression
                      (Location 0 0 0 0)
                      left
                      (let i = ((T.unpack t))
                       in (i, VariableExpression (Location 0 0 0 0) (Identifier i)))
                      right
              maybe
                (return ())
                (\op ->
                   unexpected
                     (concat
                        [ tokenString op ++
                          ". When more than one operator is used\n"
                        , "in the same expression, use parentheses, like this:\n"
                        , "(" ++
                          printExpression defaultPrint infixexp ++
                          ") " ++
                          (case op of
                             (Operator i, _) -> T.unpack i ++ " ..."
                             _ -> "* ...") ++
                          "\n"
                        , "Or like this:\n"
                        , printExpressionAppArg defaultPrint left ++
                          " " ++
                          T.unpack t ++
                          " (" ++
                          printExpressionAppArg defaultPrint right ++
                          " " ++
                          case op of
                            (Operator i, _) -> T.unpack i ++ " ...)"
                            _ -> "* ...)"
                        ]))
                badop
              pure infixexp
            _ -> pure left) <?>
      "infix expression (e.g. x * y)"
      where
        operator =
          satisfyToken
            (\case
               Operator {} -> True
               _ -> False)
    funcOp = varParser <|> constructorParser <|> parensExpr
    unambiguous = parensExpr <|> atomic
    parensExpr = parens expParser
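
-- Example (illustrative): application nests to the left, and an infix
-- expression takes exactly one operator. There is no precedence table, so
-- chains of operators (even x + y + z) must be parenthesised:
--
--   f x y      ~>  ApplicationExpression (ApplicationExpression f x) y
--   x * y      ~>  InfixExpression x ("*", VariableExpression "*") y
--   x + y * z  is rejected; the error suggests "(x + y) * ..." or
--              "x + (y * ...)"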

operatorParser
  :: Stream s m (Token, Location)
  => ParsecT s Int m (String, Expression t Identifier Location)
operatorParser = do
  tok <-
    satisfyToken
      (\case
         Operator {} -> True
         _ -> False)
  pure
    (case tok of
       (Operator t, _) ->
         let i = (T.unpack t)
         in (i, VariableExpression (Location 0 0 0 0) (Identifier i))
       _ -> error "should be operator...")

lambda :: TokenParser (Expression UnkindedType Identifier Location)
lambda = do
  loc <- equalToken Backslash <?> "lambda expression (e.g. \\x -> x)"
  args <- many1 funcParam <?> "lambda parameters"
  _ <- equalToken RightArrow
  e <- expParser
  pure (LambdaExpression loc (Alternative loc args e))

funcParams :: TokenParser [Pattern UnkindedType Identifier Location]
funcParams = many1 funcParam

funcParam :: TokenParser (Pattern UnkindedType Identifier Location)
funcParam = go <?> "function parameter (e.g. ‘x’, ‘limit’, etc.)"
  where
    go = do
      (v, loc) <-
        consumeToken
          (\case
             Variable i -> Just i
             _ -> Nothing)
      pure (VariablePattern loc (Identifier (T.unpack v)))

atomic :: TokenParser (Expression UnkindedType Identifier Location)
atomic =
  varParser <|> charParser <|> stringParser <|> integerParser <|> decimalParser <|>
  constructorParser
  where
    charParser = go <?> "character (e.g. 'a')"
      where
        go = do
          (c, loc) <-
            consumeToken
              (\case
                 Character c -> Just c
                 _ -> Nothing)
          pure (LiteralExpression loc (CharacterLiteral c))
    stringParser = go <?> "string (e.g. \"a\")"
      where
        go = do
          (c, loc) <-
            consumeToken
              (\case
                 String c -> Just c
                 _ -> Nothing)
          pure (LiteralExpression loc (StringLiteral (T.unpack c)))

    integerParser = go <?> "integer (e.g. 42, 123)"
      where
        go = do
          (c, loc) <-
            consumeToken
              (\case
                 Integer c -> Just c
                 _ -> Nothing)
          pure (LiteralExpression loc (IntegerLiteral c))
    decimalParser = go <?> "decimal (e.g. 4.2, 0.5)"
      where
        go = do
          (c, loc) <-
            consumeToken
              (\case
                 Decimal c -> Just c
                 _ -> Nothing)
          pure (LiteralExpression loc (RationalLiteral (realToFrac c)))

constructorParser :: TokenParser (Expression UnkindedType Identifier Location)
constructorParser = go <?> "constructor (e.g. Just)"
  where
    go = do
      (c, loc) <-
        consumeToken
          (\case
             Constructor c -> Just c
             _ -> Nothing)
      pure
        (ConstructorExpression loc (Identifier (T.unpack c)))

parens :: TokenParser a -> TokenParser a
parens p = go <?> "parens e.g. (x)"
  where go = do
         _ <- equalToken OpenParen
         e <- p <?> "expression inside parentheses e.g. (foo)"
         _ <- equalToken CloseParen <?> "closing parenthesis ‘)’"
         pure e

varParser :: TokenParser (Expression UnkindedType Identifier Location)
varParser = go <?> "variable (e.g. ‘foo’, ‘id’, etc.)"
  where
    go = do
      (v, loc) <-
        consumeToken
          (\case
             Variable i -> Just i
             _ -> Nothing)
      pure (if T.isPrefixOf "_" v
               then ConstantExpression loc (Identifier (T.unpack v))
               else VariableExpression loc (Identifier (T.unpack v)))
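
-- Example (illustrative): a variable whose name starts with an underscore
-- is parsed as a ConstantExpression rather than a VariableExpression:
--
--   foo   ~>  VariableExpression loc (Identifier "foo")
--   _foo  ~>  ConstantExpression loc (Identifier "_foo")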

ifParser :: TokenParser (Expression UnkindedType Identifier Location)
ifParser = go <?> "if expression (e.g. ‘if p then x else y’)"
  where
    go = do
      loc <- equalToken If
      p <- expParser <?> "condition expression of if-expression"
      _ <- equalToken Then <?> "‘then’ keyword for if-expression"
      e1 <- expParser <?> "‘then’ clause of if-expression"
      _ <- equalToken Else <?> "‘else’ keyword for if-expression"
      e2 <- expParser <?> "‘else’ clause of if-expression"
      pure
        (IfExpression
           loc
           { locationEndLine = locationEndLine (expressionLocation loc e2)
           , locationEndColumn = locationEndColumn (expressionLocation loc e2)
           }
           p
           e1
           e2)
    expressionLocation nil e = foldr const nil e


================================================
FILE: src/Duet/Printer.hs
================================================
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE Strict #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}

-- | Pretty-printers for Duet names, types, patterns and expressions.

module Duet.Printer where

import           Data.Char
import           Data.List
import qualified Data.Map.Strict as M
import           Duet.Types
import           Text.Printf

class PrintableType (t :: * -> *) where
  printType :: Printable i => Print i l -> SpecialTypes i -> t i -> String

instance PrintableType (Predicate Type) where
  printType = printPredicate

class (Eq a, Identifiable a) => Printable a where
  printit :: Print i l -> a -> String

instance Printable Name where
  printit printer =
    \case
      PrimopName primop -> printPrimop primop
      ValueName i string ->
        string ++
        (if printNameDetails printer
           then "[value:" ++ show i ++ "]"
           else "")
      TypeName i string ->
        string ++
        (if printNameDetails printer
           then "[type:" ++ show i ++ "]"
           else "")
      ConstructorName i string ->
        string ++
        (if printNameDetails printer
           then "[constructor:" ++ show i ++ "]"
           else "")
      ForallName i -> "g" ++ show i
      DictName i string ->
        string ++
        (if printNameDetails printer
           then "[dict:" ++ show i ++ "]"
           else "")
      ClassName i s ->
        s ++
        (if printNameDetails printer
           then "[class:" ++ show i ++ "]"
           else "")
      MethodName i s ->
        s ++
        (if printNameDetails printer
           then "[method:" ++ show i ++ "]"
           else "")
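
-- Example (illustrative, assuming the numeric field of ValueName and
-- ForallName is an Int):
--
--   printit defaultPrint (ValueName 3 "x")                             == "x"
--   printit defaultPrint { printNameDetails = True } (ValueName 3 "x")  == "x[value:3]"
--   printit defaultPrint (ForallName 2)                                == "g2"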

printPrimop :: Primop -> [Char]
printPrimop =
  \case
    PrimopIntegerSubtract -> "subtract"
    PrimopIntegerTimes -> "times"
    PrimopIntegerPlus -> "plus"
    PrimopRationalSubtract -> "subtract"
    PrimopRationalTimes -> "times"
    PrimopRationalPlus -> "plus"
    PrimopRationalDivide -> "divide"
    PrimopStringAppend -> "append"
    PrimopStringDrop -> "drop"
    PrimopStringTake -> "take"

instance Printable Identifier where
  printit _ =
    \case
      Identifier string -> string

defaultPrint :: Print i b
defaultPrint =
  Print
  { printDictionaries = False
  , printTypes = const Nothing
  , printNameDetails = False
  }

data Print i l = Print
  { printTypes :: (l -> Maybe (SpecialTypes i, TypeSignature Type i ()))
  , printDictionaries :: Bool
  , printNameDetails :: Bool
  }

printDataType :: (Printable i, PrintableType t) => Print i l -> SpecialTypes i -> DataType t i -> String
printDataType printer specialTypes (DataType name vars cons) =
  "data " ++ printit printer name ++ " " ++ unwords (map (printTypeVariable printer) vars) ++ "\n  = " ++
    intercalate "\n  | " (map (printConstructor printer specialTypes) cons)

printConstructor :: (Printable i, PrintableType t) => Print i l -> SpecialTypes i -> DataTypeConstructor t i -> [Char]
printConstructor printer specialTypes (DataTypeConstructor name fields) =
  printit printer name ++ " " ++ unwords (map (printType printer specialTypes) fields)

printTypeSignature
  :: (Printable i, Printable j)
  => Print i l -> SpecialTypes i -> TypeSignature Type i j -> String
printTypeSignature printer specialTypes (TypeSignature thing scheme) =
  printit printer thing ++ " :: " ++ printScheme printer specialTypes scheme

printIdentifier :: Printable j => Print i l -> j -> String
printIdentifier printer = printit printer

printImplicitlyTypedBinding
  :: (Printable i, PrintableType t)
  => Print i l -> ImplicitlyTypedBinding t i l -> String
printImplicitlyTypedBinding printer (ImplicitlyTypedBinding _ (i, _) [alt]) =
  printIdentifier printer i ++ " " ++ printAlternative printer alt
printImplicitlyTypedBinding _ _ = ""

printExplicitlyTypedBinding
  :: (Printable i, PrintableType t)
  => Print i l -> SpecialTypes i -> ExplicitlyTypedBinding t i l -> String
printExplicitlyTypedBinding printer specialTypes (ExplicitlyTypedBinding _ (i, _) scheme [alt]) =
  printIdentifier printer i ++ " :: " ++ printScheme printer specialTypes scheme ++ "\n" ++
  printIdentifier printer i ++ " " ++ printAlternative printer alt
printExplicitlyTypedBinding _ _ _ = ""

printAlternative :: (Printable i, PrintableType t) => Print i l -> Alternative t i l -> [Char]
printAlternative printer (Alternative _ patterns expression) =
  concat (map (\p->printPattern printer p ++ " ") patterns) ++ "= " ++ printExpression printer expression

printPattern :: (Printable i, PrintableType t) => Print i l -> Pattern t i l -> [Char]
printPattern printer =
  \case
    BangPattern p -> "!" ++ printPattern printer p
    VariablePattern _ i -> printIdentifier printer i
    WildcardPattern _ s -> s
    AsPattern _ i p -> printIdentifier printer i ++ "@" ++ printPattern printer p
    LiteralPattern _ l -> printLiteral l
    ConstructorPattern _ i pats ->
      printIdentifier printer i ++ " " ++ unwords (map (printPattern printer) pats)

printExpression :: (Printable i, PrintableType t) => Print i l -> (Expression t i l) -> String
printExpression printer e =
  wrapType
    (case e of
       LiteralExpression _ l -> printLiteral l
       VariableExpression _ i -> printIdentifier printer i
       ConstantExpression _ i -> printIdentifier printer i
       ConstructorExpression _ i -> printIdentifier printer i
       ParensExpression _ e -> "(" <> (printExpression printer e) <> ")"
       CaseExpression _ e alts ->
         "case " ++
         indent 5 (printExpressionIfPred printer e) ++
         " of\n" ++ indented (intercalate "\n" (map (printAlt printer) alts))
       ApplicationExpression _ f x ->
         case x of
           VariableExpression _ (nonrenamableName -> Just (DictName {}))
             | not (printDictionaries printer) -> printExpressionAppOp printer f
           _ ->
             if any (== '\n') inner || any (== '\n') prefix
               then prefix ++ "\n" ++ indented inner
               else prefix ++ " " ++ indent (length prefix + 1) inner
             where prefix = printExpressionAppOp printer f
                   inner = printExpressionAppArg printer x
       LambdaExpression _ (Alternative _ args e) ->
         if null filteredArgs
           then inner
           else if any (== '\n') inner
                  then "\\" ++ prefix ++ "->\n" ++ indented inner
                  else "\\" ++
                       prefix ++ "-> " ++ indent (length prefix + 4) inner
         where inner = (printExpression printer e)
               filteredArgs = filter dictPred args
               prefix =
                 concat (map (\x -> printPattern printer x ++ " ") filteredArgs)
               dictPred =
                 if printDictionaries printer
                   then const True
                   else \case
                          VariablePattern _ (nonrenamableName -> Just (DictName {})) ->
                            False
                          _ -> True
       IfExpression _ a b c ->
         "if " ++
         printExpressionIfPred printer a ++
         " then " ++
         printExpression printer b ++ " else " ++ printExpression printer c
       InfixExpression _ f (o, ov) x ->
         printExpressionAppArg printer f ++
         " " ++
         (if printDictionaries printer
            then "`" ++ printExpression printer ov ++ "`"
            else o) ++
         " " ++ printExpressionAppArg printer x
       _ -> "<TODO>")
  where
    wrapType x =
      case printTypes printer (expressionLabel e) of
        (Nothing) -> x
        (Just (specialTypes, TypeSignature _ ty)) ->
          "(" ++
          parens x ++ " :: " ++ printScheme printer specialTypes ty ++ ")"
          where parens k =
                  if any isSpace k
                    then "(" ++ k ++ ")"
                    else k

printAlt
  :: (PrintableType t, Printable i)
  => Print i l -> (CaseAlt t i l) -> [Char]
printAlt printer =
  \(CaseAlt _ p e') ->
    let inner = printExpression printer e'
    in if any (== '\n') inner
         then printPat printer p ++ " ->\n" ++ indented inner
         else printPat printer p ++ " -> " ++ indent 2 inner

indented :: String -> [Char]
indented x = intercalate "\n" (map ("  "++) (lines x))

indent :: Int -> String -> [Char]
indent n = intercalate ("\n" ++ replicate n ' ') . lines

lined :: [[Char]] -> [Char]
lined = intercalate "\n  "
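
-- Illustrative examples of the layout helpers above, in doctest
-- syntax. They are a sketch and are not wired into the project's test
-- suite. Note that 'indent' leaves the first line alone, so it suits
-- text that continues a line that has already been started.
--
-- >>> indented "a\nb"
-- "  a\n  b"
--
-- >>> indent 2 "a\nb"
-- "a\n  b"
--
-- >>> lined ["a", "b"]
-- "a\n  b"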

printPat :: (Printable i, PrintableType t) => Print i l -> Pattern t i l -> String
printPat printer =
  \case
    BangPattern p -> "!" ++ inner p
    VariablePattern _ i -> printit printer i
    ConstructorPattern _ i ps ->
      printit printer i ++
      (if null ps
         then ""
         else " " ++ unwords (map inner ps))
    WildcardPattern{} -> "_"
    -- Parenthesise the sub-pattern so that x@(Just y) is not printed
    -- as x@Just y, which would parse differently.
    AsPattern _ ident p -> printit printer ident ++ "@" ++ inner p
    LiteralPattern _ l -> printLiteral l
  where
    inner =
      \case
        BangPattern p -> "!" ++ inner p
        VariablePattern _ i -> printit printer i
        WildcardPattern _ s -> s
        ConstructorPattern _ i ps
          | null ps -> printit printer i
          | otherwise ->
            "(" ++ printit printer i ++ " " ++ unwords (map inner ps) ++ ")"
        AsPattern _ ident p -> printit printer ident ++ "@" ++ inner p
        LiteralPattern _ l -> printLiteral l

printExpressionAppArg :: (Printable i, PrintableType t) => Print i l -> (Expression t i l) -> String
printExpressionAppArg printer =
  \case
    e@(ApplicationExpression {})
      | nodict e -> paren (printExpression printer e)
    e@(IfExpression {}) -> paren (printExpression printer e)
    e@(InfixExpression {}) -> paren (printExpression printer e)
    e@(LambdaExpression {}) -> paren (printExpression printer e)
    e@(CaseExpression {}) -> paren (printExpression printer e)
    e -> printExpression printer e
  where
    nodict =
      \case
        ApplicationExpression _ _ (VariableExpression _ (nonrenamableName -> Just (DictName {})))
          | not (printDictionaries printer) -> False
        _ -> True

printExpressionIfPred :: (Printable i, PrintableType t) => Print i l -> (Expression t i l) -> String
printExpressionIfPred printer =
  \case
    e@(IfExpression {}) -> paren (printExpression printer e)
    e@(LambdaExpression {}) -> paren (printExpression printer e)
    e@(CaseExpression {}) -> paren (printExpression printer e)
    e -> printExpression printer e

printExpressionAppOp :: (Printable i, PrintableType t) => Print i l -> (Expression t i l) -> String
printExpressionAppOp printer =
  \case
    e@(IfExpression {}) -> paren (printExpression printer e)
    e@(LambdaExpression {}) -> paren (printExpression printer e)
    e@(CaseExpression {}) -> paren (printExpression printer e)
    e -> printExpression printer e

paren :: [Char] -> [Char]
paren e = "(" ++ indent 1 e ++ ")"
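
-- For example (doctest syntax, illustrative only), continuation lines
-- are indented by one column, so they line up under the first
-- character after the opening parenthesis:
--
-- >>> paren "f\nx"
-- "(f\n x)"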

printLiteral :: Literal -> String
printLiteral (IntegerLiteral i) = show i
printLiteral (RationalLiteral i) = printf "%f" (fromRational i :: Double)
printLiteral (StringLiteral x) = show x
printLiteral (CharacterLiteral x) = show x

printScheme :: (Printable i, PrintableType t, PrintableType t1) => Print i l -> SpecialTypes i -> Scheme t i t1 -> [Char]
printScheme printer specialTypes (Forall kinds qualifiedType') =
  (if null kinds
     then ""
     else "forall " ++
          unwords
            (zipWith
               (\_i k ->
                  printTypeVariable
                    (Print
                     { printTypes = const Nothing
                     , printDictionaries = False
                     , printNameDetails = printNameDetails printer
                     })
                    k)
               [0 :: Int ..]
               kinds) ++
          ". ") ++
  printQualifiedType specialTypes qualifiedType'
  where
    printQualifiedType specialTypes (Qualified predicates typ) =
      case predicates of
        [] -> printType printer specialTypes typ
        _ ->
          "(" ++
          intercalate
            ", "
            (map (printPredicate printer specialTypes) predicates) ++
          ") => " ++ printType printer specialTypes typ


printClass :: Printable i => Print i l -> SpecialTypes i -> Class Type i l -> String
printClass printer specialTypes (Class vars supers instances i methods) =
  "class " ++
  printSupers printer specialTypes supers ++
  printit printer i ++
  " " ++
  unwords (map (printTypeVariable printer) vars) ++ " where\n  " ++
  intercalate "\n  " (map (printMethod printer specialTypes) (M.toList methods)) ++
  "\n" ++ intercalate "\n" (map (printInstance printer specialTypes) instances)

printMethod :: Printable i => Print i l -> SpecialTypes i -> (i, Scheme Type i Type) -> String
printMethod printer specialTypes (i, scheme) =
  printit printer i ++ " :: " ++ printScheme printer specialTypes scheme

printInstance :: Printable i => Print i l -> SpecialTypes i -> Instance Type i l -> String
printInstance printer specialTypes (Instance scheme _) =
  "instance " ++
  printScheme printer specialTypes scheme

printSupers :: Printable i => Print i l -> SpecialTypes i -> [Predicate Type i] -> [Char]
printSupers printer specialTypes supers
  | null supers = ""
  | otherwise =
    "(" ++ intercalate ", " (map (printPredicate printer specialTypes) supers) ++ ") => "


printPredicate :: (Printable i, PrintableType t) => Print i l -> SpecialTypes i -> Predicate t i -> [Char]
printPredicate printer specialTypes (IsIn identifier types) =
  printIdentifier printer identifier ++
  " " ++ unwords (map (wrap . printType printer specialTypes) types)
  where wrap x = if any isSpace x
                    then "(" ++ x ++ ")"
                    else x

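printKind :: Kind -> [Char]
printKind =
  \case
    StarKind -> "Type"
    -- A function kind in argument position needs parentheses:
    -- (Type -> Type) -> Type, not Type -> Type -> Type.
    FunctionKind x'@FunctionKind {} y ->
      "(" ++ printKind x' ++ ") -> " ++ printKind y
    FunctionKind x' y -> printKind x' ++ " -> " ++ printKind y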

printTypeSansParens :: (Printable i) => Print i l -> SpecialTypes i -> Type i -> [Char]
printTypeSansParens printer specialTypes =
  \case
    ApplicationType (ApplicationType func x') y'
      | func == ConstructorType (specialTypesFunction specialTypes) ->
        printType printer specialTypes x' ++
        " -> " ++ printTypeSansParens printer specialTypes y'
    o -> printType printer specialTypes o

instance PrintableType Type where
  printType printer specialTypes =
    \case
      VariableType v -> printTypeVariable printer v
      ConstructorType tyCon -> printTypeConstructor printer tyCon
      ApplicationType (ApplicationType func x') y
        | func == ConstructorType (specialTypesFunction specialTypes) ->
          "(" ++
          printType printer specialTypes x' ++
          " -> " ++ printTypeSansParens printer specialTypes y ++ ")"
    -- ApplicationType list ty | list == specialTypesList specialTypes ->
    --   "[" ++ printTypeSansParens specialTypes ty ++ "]"
      ApplicationType x' y ->
        printType printer specialTypes x' ++ " " ++ printTypeArg y
      -- GenericType int -> "g" ++ show int
    where
      printTypeArg =
        \case
          x@ApplicationType {} -> "(" ++ printType printer specialTypes x ++ ")"
          x -> printType printer specialTypes x

instance PrintableType UnkindedType where
  printType printer specialTypes =
    \case
      UnkindedTypeVariable v -> printIdentifier printer v
      UnkindedTypeConstructor tyCon -> printIdentifier printer tyCon
      UnkindedTypeApp x' y ->
        "(" ++ printType printer specialTypes x' ++ " " ++ printType printer specialTypes y ++ ")"

printTypeConstructor :: Printable j => Print i l -> TypeConstructor j -> String
printTypeConstructor printer (TypeConstructor identifier kind) =
  case kind of
    StarKind -> printIdentifier printer identifier
    FunctionKind {} -> printIdentifier printer identifier
        -- _ -> "(" ++ printIdentifier identifier ++ " :: " ++ printKind kind ++ ")"

printTypeVariable :: Printable i => Print i l -> TypeVariable i -> String
printTypeVariable printer (TypeVariable identifier kind) =
  case kind of
    StarKind -> printIdentifier printer identifier
    _ -> "(" ++ printIdentifier printer identifier ++ " :: " ++ printKind kind ++ ")"

curlyQuotes :: [Char] -> [Char]
curlyQuotes t = "‘" <> t <> "’"


================================================
FILE: src/Duet/Renamer.hs
================================================
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}

-- At each binding point (lambda arguments, case alternatives, let
-- bindings), we need to supply a new unique name and then rename
-- everything inside the expression.
--
-- For each BindGroup, we first generate a unique name for every
-- top-level binding (the bindings may be mutually dependent), and then
-- run the sub-renaming processes with the new substitutions in scope.
--
-- It's as simple as that.
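--
-- For example, schematically (the digits stand for whatever fresh
-- integers the supply happens to hand out):
--
--   \x -> \x -> x   ==>   \x1 -> \x2 -> x2
--
-- The inner binding shadows the outer one because renameAlt merges the
-- pattern-bound names into the substitution with Data.Map's
-- left-biased union.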

module Duet.Renamer
  ( renameDataTypes
  , renameBindings
  , renameBindGroups
  , renameExpression
  , renameClass
  , renameInstance
  , predicateToDict
  , operatorTable
  , Specials(Specials)
  ) where

import           Control.Arrow
import           Control.Monad.Catch
import           Control.Monad.Supply
import           Control.Monad.Trans
import           Control.Monad.Writer
import           Data.Char
import           Data.List
import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import           Data.Maybe
import           Duet.Infer
import           Duet.Printer
import           Duet.Supply
import           Duet.Types

--------------------------------------------------------------------------------
-- Data type renaming (this includes kind checking)

renameDataTypes
  :: (MonadSupply Int m, MonadThrow m)
  => Specials Name
  -> [DataType UnkindedType Identifier]
  -> m [DataType Type Name]
renameDataTypes specials types = do
  typeConstructors <-
    mapM
      (\(DataType name vars cs) -> do
         name' <- supplyTypeName name
         vars' <-
           mapM
             (\(TypeVariable i k) -> do
                i' <- supplyTypeName i
                pure (i, TypeVariable i' k))
             vars
         pure (name, name', vars', cs))
      types
  mapM
    (\(_, name, vars, cs) -> do
       cs' <- mapM (renameConstructor specials typeConstructors vars) cs
       pure (DataType name (map snd vars) cs'))
    typeConstructors

renameConstructor
  :: (MonadSupply Int m, MonadThrow m)
  => Specials Name -> [(Identifier, Name, [(Identifier, TypeVariable Name)], [DataTypeConstructor UnkindedType Identifier])]
  -> [(Identifier, TypeVariable Name)]
  -> DataTypeConstructor UnkindedType Identifier
  -> m (DataTypeConstructor Type Name)
renameConstructor specials typeConstructors vars (DataTypeConstructor name fields) = do
  name' <- supplyConstructorName name
  fields' <- mapM (renameField specials typeConstructors vars name') fields
  pure (DataTypeConstructor name' fields')

renameField
  :: (MonadThrow m, MonadSupply Int m)
  => Specials Name
  -> [(Identifier, Name, [(Identifier, TypeVariable Name)], [DataTypeConstructor UnkindedType Identifier])]
  -> [(Identifier, TypeVariable Name)]
  -> Name
  -> UnkindedType Identifier
  -> m (Type Name)
renameField specials typeConstructors vars name fe = do
  ty <- go fe
  if typeKind ty == StarKind
    then pure ty
    else throwM (ConstructorFieldKind name ty (typeKind ty))
  where
    go =
      \case
        UnkindedTypeConstructor i -> do
          (name', vars') <- resolve i
          pure (ConstructorType (toTypeConstructor name' (map snd vars')))
        UnkindedTypeVariable v ->
          case lookup v vars of
            Nothing -> throwM (UnknownTypeVariable (map snd vars) v)
            Just tyvar -> pure (VariableType tyvar)
        UnkindedTypeApp f x -> do
          f' <- go f
          let fKind = typeKind f'
          case fKind of
            FunctionKind argKind _ -> do
              x' <- go x
              let xKind = typeKind x'
              if xKind == argKind
                then pure (ApplicationType f' x')
                else throwM (KindArgMismatch f' fKind x' xKind)
            StarKind -> do
              x' <- go x
              throwM (KindTooManyArgs f' fKind x')
    resolve i =
      case find ((\(j, _, _, _) -> j == i)) typeConstructors of
        Just (_, name', vs, _) -> pure (name', vs)
        Nothing ->
          case specialTypesFunction (specialsTypes specials) of
            TypeConstructor n@(TypeName _ i') _
              | Identifier i' == i -> do
                fvars <-
                  mapM
                    (\vari ->
                       (vari, ) <$>
                       fmap
                         (\varn -> TypeVariable varn StarKind)
                         (supplyTypeVariableName vari))
                    (map Identifier ["a", "b"])
                pure (n, fvars)
            _ ->
              case listToMaybe (mapMaybe (matches i) builtinStarTypes) of
                Just ty -> pure ty
                Nothing ->
                  case find
                         (\case
                            TypeName _ tyi -> Identifier tyi == i
                            _ -> False)
                         (map
                            typeConstructorIdentifier
                            [ specialTypesChar (specialsTypes specials)
                            , specialTypesInteger (specialsTypes specials)
                            , specialTypesRational (specialsTypes specials)
                            , specialTypesString (specialsTypes specials)
                            ]) of
                    Just ty -> pure (ty, [])
                    _ -> throwM (TypeNotInScope [] i)
    matches i t =
      case t of
        DataType n@(TypeName _ i') vs _
          | Identifier i' == i ->
            Just
              ( n
              , mapMaybe
                  (\case
                     (TypeVariable n'@(TypeName _ tyi) k) ->
                       Just (Identifier tyi, TypeVariable n' k)
                     _ -> Nothing)
                  vs)
        _ -> Nothing
    builtinStarTypes = [specialTypesBool (specialsTypes specials)]

--------------------------------------------------------------------------------
-- Class renaming

renameClass
  :: forall m.
     (MonadSupply Int m, MonadThrow m)
  => Specials Name
  -> Map Identifier Name
  -> [DataType Type Name]
  -> Class UnkindedType Identifier Location
  -> m (Class Type Name Location)
renameClass specials subs types cls = do
  name <- supplyClassName (className cls)
  classVars <-
    mapM
      (\(TypeVariable i k) -> do
         i' <- supplyTypeName i
         pure (i, TypeVariable i' k))
      (classTypeVariables cls)
  instances <-
    mapM
      (renameInstance' specials subs types classVars)
      (classInstances cls)
  methods' <-
    fmap
      M.fromList
      (mapM
         (\(mname, (Forall vars (Qualified preds ty))) -> do
            name' <- supplyMethodName mname
            methodVars <- mapM (renameMethodTyVar classVars) vars
            let classAndMethodVars = nub (classVars ++ methodVars)
            ty' <- renameType specials classAndMethodVars types ty
            preds' <-
              mapM
                (\(IsIn c tys) ->
                   IsIn <$> substituteClass subs c <*>
                   mapM (renameType specials classAndMethodVars types) tys)
                preds
            pure
              ( name'
              , (Forall (map snd classAndMethodVars) (Qualified preds' ty'))))
         (M.toList (classMethods cls)))
  pure
    (Class
     { className = name
     , classTypeVariables = map snd classVars
     , classSuperclasses = []
     , classInstances = instances
     , classMethods = methods'
     })
  where
    renameMethodTyVar
      :: [(Identifier, TypeVariable Name)]
      -> TypeVariable Identifier
      -> m (Identifier, TypeVariable Name)
    renameMethodTyVar classTable (TypeVariable ident k) =
      case lookup ident classTable of
        Nothing -> do
          i' <- supplyTypeName ident
          pure (ident, TypeVariable i' k)
        Just v -> pure (ident, v)

--------------------------------------------------------------------------------
-- Instance renaming

renameInstance
  :: (MonadThrow m, MonadSupply Int m)
  => Specials Name
  -> Map Identifier Name
  -> [DataType Type Name]
  -> [Class Type Name l]
  -> Instance UnkindedType Identifier Location
  -> m (Instance Type Name Location)
renameInstance specials subs types classes inst@(Instance (Forall _ (Qualified _ (IsIn className' _))) _) = do
  table <- mapM (\c -> fmap (, c) (identifyClass (className c))) classes
  case lookup className' table of
    Nothing ->
      throwM
        (IdentifierNotInClassScope
           (M.fromList (map (second className) table))
           className')
    Just typeClass -> do
      vars <-
        mapM
          (\v@(TypeVariable i _) -> fmap (, v) (identifyType i))
          (classTypeVariables typeClass)
      instr <- renameInstance' specials subs types vars inst
      pure instr

renameInstance'
  :: (MonadThrow m, MonadSupply Int m)
  => Specials Name
  -> Map Identifier Name
  -> [DataType Type Name]
  -> [(Identifier, TypeVariable Name)]
  -> Instance UnkindedType Identifier Location
  -> m (Instance Type Name Location)
renameInstance' specials subs types _tyVars (Instance (Forall vars (Qualified preds ty)) dict) = do
  let vars0 =
        nub
          (if null vars
              then concat
                     (map
                        collectTypeVariables
                        (case ty of
                           IsIn _ t -> t))
              else vars)
  vars'' <-
    mapM
      (\(TypeVariable i k) -> do
         n <- supplyTypeName i
         pure (i, TypeVariable n k))
      vars0
  preds' <- mapM (renamePredicate specials subs vars'' types) preds
  ty' <- renamePredicate specials subs vars'' types ty
  dict' <- renameDict specials subs types dict ty'
  pure (Instance (Forall (map snd vars'') (Qualified preds' ty')) dict')
  where
    collectTypeVariables :: UnkindedType i -> [TypeVariable i]
    collectTypeVariables =
      \case
        UnkindedTypeConstructor {} -> []
        UnkindedTypeVariable i -> [TypeVariable i StarKind]
        UnkindedTypeApp f x -> collectTypeVariables f ++ collectTypeVariables x

renameDict
  :: (MonadThrow m, MonadSupply Int m)
  => Specials Name
  -> Map Identifier Name
  -> [DataType Type Name]
  -> Dictionary UnkindedType Identifier Location
  -> Predicate Type Name
  -> m (Dictionary Type Name Location)
renameDict specials subs types (Dictionary _ methods) predicate = do
  name' <-
    supplyDictName'
      (Identifier (predicateToDict specials predicate))
  methods' <-
    fmap
      M.fromList
      (mapM
         (\(n, (l, alt)) -> do
            n' <- supplyMethodName n
            alt' <- renameAlt specials subs types alt
            pure (n', (l, alt')))
         (M.toList methods))
  pure (Dictionary name' methods')

predicateToDict :: Specials Name -> ((Predicate Type Name)) -> String
predicateToDict specials p =
  "$dict" ++ map normalize (printPredicate defaultPrint (specialsTypes specials) p)
  where
    normalize c
      | isDigit c || isLetter c = c
      | otherwise = '_'
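
-- For example, a predicate that the default printer renders as
-- "Show (Maybe a)" yields the dictionary name "$dictShow__Maybe_a_":
-- each space and each parenthesis becomes an underscore. This assumes
-- the predicate prints that way; the exact rendering of names is
-- determined by printPredicate.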


renamePredicate
  :: (MonadThrow m, Typish (t i), Identifiable i)
  => Specials Name
  -> Map Identifier Name
  -> [(Identifier, TypeVariable Name)]
  -> [DataType Type Name]
  -> Predicate t i
  -> m (Predicate Type Name)
renamePredicate specials subs tyVars types (IsIn className' types0) =
  do subbedClassName <- substituteClass subs className'
     types' <- mapM (renameType specials tyVars types -- >=> forceStarKind
                    ) types0
     pure (IsIn subbedClassName types')

-- | Force that the type has kind *.
_forceStarKind :: MonadThrow m => Type Name -> m (Type Name)
_forceStarKind ty =
  case typeKind ty of
    StarKind -> pure ty
    _ -> throwM (MustBeStarKind ty (typeKind ty))

renameScheme
  :: (MonadSupply Int m, MonadThrow m, Identifiable i, Typish (t i))
  => Specials Name
  -> Map Identifier Name
  -> [DataType Type Name]
  -> Scheme t i t
  -> m (Scheme Type Name Type)
renameScheme specials subs types (Forall tyvars (Qualified ps ty)) = do
  tyvars' <-
    mapM
      (\(TypeVariable i kind) -> do
         n <-
           case nonrenamableName i of
             Just k -> pure k
             Nothing -> do
               i' <- identifyType i
               supplyTypeName i'
         ident <- identifyType n
         pure (ident, TypeVariable n kind))
      tyvars
  ps' <- mapM (renamePredicate specials subs tyvars' types) ps
  ty' <- renameType specials tyvars' types ty
  pure (Forall (map snd tyvars') (Qualified ps' ty'))

-- | Rename a type, checking kinds, taking names, etc.
renameType
  :: (MonadThrow m, Typish (t i))
  => Specials Name
  -> [(Identifier, TypeVariable Name)]
  -> [DataType Type Name]
  -> t i
  -> m (Type Name)
renameType specials tyVars types t = either go pure (isType t)
  where
    go =
      \case
        UnkindedTypeConstructor i -> do
          ms <- mapM (\p -> fmap (, p) (identifyType (dataTypeName p))) types
          case lookup i ms of
            Nothing -> do
              do specials'' <- sequence specials'
                 case lookup i specials'' of
                   Nothing ->
                     throwM
                       (TypeNotInScope
                          (map dataTypeToConstructor (map snd ms))
                          i)
                   Just t' -> pure (ConstructorType t')
            Just dty -> pure (dataTypeConstructor dty)
        UnkindedTypeVariable i -> do
          case lookup i tyVars of
            Nothing -> throwM (UnknownTypeVariable (map snd tyVars) i)
            Just ty -> do
              pure (VariableType ty)
        UnkindedTypeApp f a -> do
          f' <- go f
          case typeKind f' of
            FunctionKind argKind _ -> do
              a' <- go a
              if typeKind a' == argKind
                then pure (ApplicationType f' a')
                else throwM (KindArgMismatch f' (typeKind f') a' (typeKind a'))
            StarKind -> do
              a' <- go a
              throwM (KindTooManyArgs f' (typeKind f') a')
    specials' =
      [ setup (specialTypesFunction . specialsTypes)
      , setup (specialTypesInteger . specialsTypes)
      , setup (specialTypesChar . specialsTypes)
      , setup (specialTypesRational . specialsTypes)
      , setup (specialTypesString . specialsTypes)
      , setup (dataTypeToConstructor . specialTypesBool . specialsTypes)
      ]
      where
        setup f = do
          i <- identifyType (typeConstructorIdentifier (f specials))
          pure (i, f specials)

--------------------------------------------------------------------------------
-- Value renaming

renameBindGroups
  :: ( MonadSupply Int m
     , MonadThrow m
     , Ord i
     , Identifiable i
     , Typish (UnkindedType i)
     )
  => Specials Name
  -> Map Identifier Name
  -> [DataType Type Name]
  -> [BindGroup UnkindedType i Location]
  -> m ([BindGroup Type Name Location], Map Identifier Name)
renameBindGroups specials subs types groups = do
  subs' <-
    fmap
      mconcat
      (mapM
         (\(BindGroup explicit implicit) -> do
            implicit' <- getImplicitSubs subs implicit
            explicit' <- getExplicitSubs subs explicit
            pure (explicit' <> implicit'))
         groups)
  fmap
    (second mconcat . unzip)
    (mapM (renameBindGroup specials subs' types) groups)

renameBindings
  :: (MonadSupply Int m, MonadThrow m, Ord i, Identifiable i, Typish (t i))
  => Specials Name
  -> Map Identifier Name
  -> [DataType Type Name]
  -> [Binding t i Location]
  -> m ([Binding Type Name Location], Map Identifier Name)
renameBindings specials subs types bindings = do
  subs' <-
    fmap
      ((<> subs) . M.fromList)
      (mapM
         (\case
            ExplicitBinding (ExplicitlyTypedBinding _ (i, _) _ _) -> do
              v <- identifyValue i
              fmap (v, ) (supplyValueName i)
            ImplicitBinding (ImplicitlyTypedBinding _ (i, _) _) -> do
              v <- identifyValue i
              fmap (v, ) (supplyValueName i))
         bindings)
  bindings' <-
    mapM
      (\case
         ExplicitBinding e ->
           ExplicitBinding <$> renameExplicit specials subs' types e
         ImplicitBinding i ->
           ImplicitBinding <$> renameImplicit specials subs' types i)
      bindings
  pure (bindings', subs')

renameBindGroup
  :: (MonadSupply Int m, MonadThrow m, Ord i, Identifiable i, Typish (t i))
  => Specials Name
  -> Map Identifier Name
  -> [DataType Type Name]
  -> BindGroup t i Location
  -> m (BindGroup Type Name Location, Map Identifier Name)
renameBindGroup specials subs types (BindGroup explicit implicit) = do
  bindGroup' <-
    BindGroup <$> mapM (renameExplicit specials subs types) explicit <*>
    mapM (mapM (renameImplicit specials subs types)) implicit
  pure (bindGroup', subs)

getImplicitSubs
  :: (MonadSupply Int m, Identifiable i, MonadThrow m)
  => Map Identifier Name
  -> [[ImplicitlyTypedBinding t i l]]
  -> m (Map Identifier Name)
getImplicitSubs subs implicit =
  fmap
    ((<> subs) . M.fromList)
    (mapM
       (\(ImplicitlyTypedBinding _ (i, _) _) -> do
          v <- identifyValue i
          fmap (v, ) (supplyValueName i))
       (concat implicit))

getExplicitSubs
  :: (MonadSupply Int m, Identifiable i, MonadThrow m)
  => Map Identifier Name
  -> [ExplicitlyTypedBinding t i l]
  -> m (Map Identifier Name)
getExplicitSubs subs explicit =
  fmap
    ((<> subs) . M.fromList)
    (mapM
       (\(ExplicitlyTypedBinding _ (i, _) _ _) -> do
          v <- identifyValue i
          fmap (v, ) (supplyValueName i))
       explicit)

renameExplicit
  :: (MonadSupply Int m, MonadThrow m, Identifiable i, Ord i, Typish (t i))
  => Specials Name
  -> Map Identifier Name
  -> [DataType Type Name]
  -> ExplicitlyTypedBinding t i Location
  -> m (ExplicitlyTypedBinding Type Name Location)
renameExplicit specials subs types (ExplicitlyTypedBinding l (i, l') scheme alts) = do
  name <- substituteVar subs i l'
  ExplicitlyTypedBinding l (name, l') <$> renameScheme specials subs types scheme <*>
    mapM (renameAlt specials subs types) alts

renameImplicit
  :: (MonadThrow m, MonadSupply Int m, Ord i, Identifiable i, Typish (t i))
  => Specials Name
  -> Map Identifier Name
  -> [DataType Type Name]
  -> ImplicitlyTypedBinding t i Location
  -> m (ImplicitlyTypedBinding Type Name Location)
renameImplicit specials subs types (ImplicitlyTypedBinding l (id',l') alts) =
  do name <- substituteVar subs id' l'
     ImplicitlyTypedBinding l (name, l') <$> mapM (renameAlt specials subs types) alts

renameAlt ::
     (MonadSupply Int m, MonadThrow m, Ord i, Identifiable i, Typish (t i))
  => Specials Name
  -> Map Identifier Name
  -> [DataType Type Name]
  -> Alternative t i Location
  -> m (Alternative Type Name Location)
renameAlt specials subs types (Alternative l ps e) =
  do (ps', subs') <- runWriterT (mapM (renamePattern subs) ps)
     let subs'' = M.fromList subs' <> subs
     Alternative l <$> pure ps' <*> renameExpression specials subs'' types e
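
-- In renameAlt above, pattern-bound names shadow outer bindings because
-- Data.Map's (<>) is a left-biased union: on a key clash, the left map
-- (the freshly supplied names) wins. An illustrative doctest:
--
-- >>> M.fromList [("x", 2)] <> M.fromList [("x", 1), ("y", 1)]
-- fromList [("x",2),("y",1)]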

renamePattern
  :: (MonadSupply Int m, MonadThrow m, Ord i, Identifiable i)
  => Map Identifier Name
  -> Pattern t i l
  -> WriterT [(Identifier, Name)] m (Pattern Type Name l)
renamePattern subs =
  \case
    BangPattern p -> fmap BangPattern (renamePattern subs p)
    VariablePattern l i -> do
      name <- maybe (lift (supplyValueName i)) pure (nonrenamableName i)
      v <- identifyValue i
      tell [(v, name)]
      pure (VariablePattern l name)
    WildcardPattern l s -> pure (WildcardPattern l s)
    AsPattern l i p -> do
      name <- supplyValueName i
      v <- identifyValue i
      tell [(v, name)]
      AsPattern l name <$> renamePattern subs p
    LiteralPattern l0 l -> pure (LiteralPattern l0 l)
    ConstructorPattern l i pats ->
      ConstructorPattern l <$> substituteCons subs i <*>
      mapM (renamePattern subs) pats

class Typish t where isType :: t -> Either (UnkindedType Identifier) (Type Name)
instance Typish (Type Name) where isType = Right
instance Typish (UnkindedType Identifier) where isType = Left

renameExpression
  :: forall t i m.
     (MonadThrow m, MonadSupply Int m, Ord i, Identifiable i, Typish (t i))
  => Specials Name
  -> Map Identifier Name
  -> [DataType Type Name]
  -> Expression t i Location
  -> m (Expression Type Name Location)
renameExpression specials subs types = go
  where
    go :: Expression t i Location -> m (Expression Type Name Location)
    go =
      \case
        ParensExpression l e -> ParensExpression l <$> go e
        VariableExpression l i -> VariableExpression l <$> substituteVar subs i l
        ConstructorExpression l i ->
          ConstructorExpression l <$> substituteCons subs i
        ConstantExpression l i -> pure (ConstantExpression l i)
        LiteralExpression l i -> pure (LiteralExpression l i)
        ApplicationExpression l f x -> ApplicationExpression l <$> go f <*> go x
        InfixExpression l x (orig, VariableExpression l0 i) y -> do
          i' <-
            case nonrenamableName i of
              Just nr -> pure nr
              Nothing -> do
                ident <- identifyValue i
                case lookup ident operatorTable of
                  Just f -> pure (f (specialsSigs specials))
                  _ -> throwM (IdentifierNotInVarScope subs ident l0)
          InfixExpression l <$> go x <*> pure (orig, VariableExpression l0 i') <*>
            go y
        InfixExpression l x (orig, o) y ->
          InfixExpression l <$> go x <*> fmap (orig,) (go o) <*> go y
        LetExpression l bindGroup@(BindGroup ex implicit) e -> do
          subs0 <- getImplicitSubs subs implicit
          subs1 <- getExplicitSubs subs ex
          (bindGroup', subs'') <-
            renameBindGroup specials (subs0 <> subs1) types bindGroup
          LetExpression l <$> pure bindGroup' <*>
            renameExpression specials subs'' types e
        LambdaExpression l alt ->
          LambdaExpression l <$> renameAlt specials subs types alt
        IfExpression l x y z -> IfExpression l <$> go x <*> go y <*> go z
        CaseExpression l e pat_exps ->
          CaseExpression l <$> go e <*>
          mapM
            (\(CaseAlt l1 pat ex) -> do
               (pat', subs') <- runWriterT (renamePattern subs pat)
               e' <-
                 renameExpression specials (M.fromList subs' <> subs) types ex
               pure (CaseAlt l1 pat' e'))
            pat_exps

--------------------------------------------------------------------------------
-- Provide a substitution

substituteVar :: (Identifiable i, MonadThrow m) => Map Identifier Name -> i -> Location -> m Name
substituteVar subs i0 l =
  case nonrenamableName i0 of
    Nothing -> do
      i <- identifyValue i0
      case M.lookup i subs of
        Just name@ValueName {} -> pure name
        Just name@MethodName {} -> pure name
        Just name@DictName {} -> pure name
        _ -> do
          s <- identifyValue i
          throwM (IdentifierNotInVarScope subs s l)
    Just n -> pure n

substituteClass :: (Identifiable i, MonadThrow m) => Map Identifier Name -> i -> m Name
substituteClass subs i0 =
  do i <- identifyValue i0
     case M.lookup i subs of
       Just name@ClassName{} -> pure name
       _ -> do s <- identifyValue i
               throwM (IdentifierNotInClassScope subs s)

substituteCons :: (Identifiable i, MonadThrow m) => Map Identifier Name -> i -> m Name
substituteCons subs i0 =
  do i <- identifyValue i0
     case M.lookup i subs of
       Just name@ConstructorName{} -> pure name
       _ -> throwM (IdentifierNotInConScope subs i)

operatorTable :: [(Identifier, SpecialSigs i -> i)]
operatorTable =
  map
    (first Identifier)
    [ ("+", specialSigsPlus)
    , ("-", specialSigsSubtract)
    , ("*", specialSigsTimes)
    , ("/", specialSigsDivide)
    ]
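In `renameExpression`, an infix operator that is not already a renamed name is resolved only through `operatorTable`. Each entry maps a symbol to a field of `SpecialSigs`, so `x + y` ends up referring to the `plus` method name created by `setupEnv`. Any other operator fails with `IdentifierNotInVarScope`. The following is a minimal, self-contained sketch of that lookup. `Sigs`, `sigPlus` and `resolveOperator` are hypothetical stand-ins, and errors are returned in `Either` rather than thrown via `MonadThrow` as the renamer does.

```haskell
-- Hypothetical stand-in for Duet's SpecialSigs: one method name per operator.
data Sigs i = Sigs
  { sigPlus, sigSubtract, sigTimes, sigDivide :: i }

-- Same shape as the renamer's operatorTable: operator symbol -> field accessor.
operatorTable :: [(String, Sigs i -> i)]
operatorTable =
  [ ("+", sigPlus)
  , ("-", sigSubtract)
  , ("*", sigTimes)
  , ("/", sigDivide)
  ]

-- Resolve an infix operator to the method name it denotes, or report it as
-- out of scope (the real renamer throws IdentifierNotInVarScope instead).
resolveOperator :: Sigs i -> String -> Either String i
resolveOperator sigs op =
  case lookup op operatorTable of
    Just field -> Right (field sigs)
    Nothing -> Left ("operator not in scope: " ++ op)
```

The table stores accessors rather than concrete names. It can therefore be defined once at the top level and applied to whichever `SpecialSigs` value `setupEnv` produced at runtime.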


================================================
FILE: src/Duet/Resolver.hs
================================================
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE Strict #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NoMonomorphismRestriction #-}

-- | Resolve type-class instances.

module Duet.Resolver where

import           Control.Monad.Catch
import           Control.Monad.Supply
import           Data.List
import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import           Data.Maybe
import           Duet.Infer
import           Duet.Printer
import           Duet.Supply
import           Duet.Types

resolveTypeClasses
  :: (MonadSupply Int f, MonadThrow f)
  => Map Name (Class Type Name (TypeSignature Type Name l))
  -> SpecialTypes Name
  -> f (Map Name (Class Type Name (TypeSignature Type Name l)))
resolveTypeClasses typeClasses specialTypes = go typeClasses
  where
    go =
      fmap M.fromList .
      mapM
        (\(name, cls) -> do
           is <-
             mapM
               (\inst -> do
                  ms <-
                    mapM
                      (\(nam, (l, alt)) ->
                         fmap ((nam, ) . (l, )) (resolveAlt typeClasses specialTypes alt))
                      (M.toList (dictionaryMethods (instanceDictionary inst)))
                  pure
                    inst
                    { instanceDictionary =
                        (instanceDictionary inst)
                        {dictionaryMethods = M.fromList ms}
                    })
               (classInstances cls)
           pure (name, cls {classInstances = is})) .
      M.toList

resolveBindGroup
  :: (MonadSupply Int m, MonadThrow m)
  => Map Name (Class Type Name (TypeSignature Type Name l))
  -> SpecialTypes Name
  -> BindGroup Type Name (TypeSignature Type Name l)
  -> m (BindGroup Type Name (TypeSignature Type Name l))
resolveBindGroup classes specialTypes (BindGroup explicit implicit) = do
  explicits <- mapM (resolveExplicit classes specialTypes) explicit
  implicits <- mapM (mapM (resolveImplicit classes specialTypes)) implicit
  pure (BindGroup explicits implicits)

resolveImplicit
  :: (MonadSupply Int m, MonadThrow m)
  => Map Name (Class Type Name (TypeSignature Type Name l))
  -> SpecialTypes Name
  -> ImplicitlyTypedBinding Type Name (TypeSignature Type Name l)
  -> m (ImplicitlyTypedBinding Type Name (TypeSignature Type Name l))
resolveImplicit classes specialTypes (ImplicitlyTypedBinding l name alts) =
  ImplicitlyTypedBinding l name <$> mapM (resolveAlt classes specialTypes) alts

resolveExplicit
  :: (MonadSupply Int m, MonadThrow m)
  => Map Name (Class Type Name (TypeSignature Type Name l))
  -> SpecialTypes Name
  -> ExplicitlyTypedBinding Type Name (TypeSignature Type Name l)
  -> m (ExplicitlyTypedBinding Type Name (TypeSignature Type Name l))
resolveExplicit classes specialTypes (ExplicitlyTypedBinding l scheme name alts) =
  ExplicitlyTypedBinding l scheme name <$> mapM (resolveAlt classes specialTypes) alts

resolveAlt
  :: (MonadSupply Int m, MonadThrow m)
  => Map Name (Class Type Name (TypeSignature Type Name l))
  -> SpecialTypes Name
  -> Alternative Type Name (TypeSignature Type Name l)
  -> m (Alternative Type Name (TypeSignature Type Name l))
resolveAlt classes specialTypes (Alternative l ps e) = do
  dicts <-
    mapM
      (\pred' ->
         (pred', ) <$> supplyDictName (predicateToString specialTypes pred'))
      (filter (\p -> (not (isJust (byInst classes p)))) (nub predicates))
  (Alternative l <$> pure ps <*>
   resolveExp
     classes
     specialTypes
     dicts
     (if null dicts
        then e
        else let dictArgs = [VariablePattern l d | (_, d) <- dicts]
             in case e of
                  LambdaExpression _ (Alternative l0 args e0) ->
                    LambdaExpression l (Alternative l0 (dictArgs ++ args) e0)
                  _ -> LambdaExpression l (Alternative l dictArgs e)))
  where
    Forall _ (Qualified predicates _) = typeSignatureScheme l

predicateToString
  :: (Printable i)
  => SpecialTypes i -> Predicate Type i -> String
predicateToString _specialTypes (IsIn name _ts) =
  -- printIdentifier name ++ " " ++ unwords (map (printType specialTypes) ts)
  "?dict" ++ printIdentifier defaultPrint name

resolveExp
  :: (MonadThrow m)
  => Map Name (Class Type Name (TypeSignature Type Name l))
  -> SpecialTypes Name
  -> [(Predicate Type Name, Name)]
  -> Expression Type Name (TypeSignature Type Name l)
  -> m (Expression Type Name (TypeSignature Type Name l))
resolveExp classes _ dicts = go
  where
    go =
      \case
        ParensExpression l e -> ParensExpression l <$> go e
        VariableExpression l i -> do
          dictArgs <- fmap concat (mapM (lookupDictionary l) predicates)
          pure
            (foldl (ApplicationExpression l) (VariableExpression l i) dictArgs)
          where Forall _ (Qualified predicates _) = typeSignatureScheme l
        ApplicationExpression l f x -> ApplicationExpression l <$> go f <*> go x
        InfixExpression l x (i, op) y ->
          InfixExpression l <$> go x <*> fmap (i, ) (go op) <*> go y
        LambdaExpression l0 (Alternative l vs b) ->
          LambdaExpression l0 <$> (Alternative l vs <$> go b)
        CaseExpression l e alts ->
          CaseExpression l <$> go e <*>
          mapM (\(CaseAlt l' p e') -> fmap (CaseAlt l' p) (go e')) alts
        e@ConstructorExpression {} -> pure e
        e@ConstantExpression {} -> pure e
        IfExpression l a b c -> IfExpression l <$> go a <*> go b <*> go c
        e@LiteralExpression {} -> pure e
        LetExpression {} -> error "Let expressions not supported."
    lookupDictionary l p =
      (case byInst classes p of
         Just (preds, dict) -> do
           parents <- fmap concat (mapM (lookupDictionary l) preds)
           pure (VariableExpression l (dictionaryName dict) : parents)
         Nothing ->
           case lookup p dicts of
             Nothing -> throwM (NoInstanceFor p)
             Just v -> pure [VariableExpression l v])
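`resolveAlt` and `resolveExp` implement a dictionary-passing translation. Each class predicate in a binding's type that no instance satisfies (`byInst` returns `Nothing`) becomes an extra leading lambda parameter. Each variable use is then applied to the dictionaries its own predicates require. A dictionary is either a concrete instance dictionary found by `byInst` or one of the enclosing binding's dictionary parameters. The sketch below shows the same idea hand-written in plain Haskell. The record `NumDict` and the functions `square`, `sumSquares` and `sumSquaresInteger` are hypothetical. Duet itself represents dictionaries as named `Dictionary` values stored in the class environment rather than as records.

```haskell
-- Hypothetical record encoding of a Num dictionary (Duet's Num class has
-- the methods plus and times).
data NumDict a = NumDict
  { dictPlus  :: a -> a -> a
  , dictTimes :: a -> a -> a
  }

-- Concrete instance dictionaries, analogous to the Integer and Rational
-- instances that setupEnv builds from primops.
numInteger :: NumDict Integer
numInteger = NumDict (+) (*)

numRational :: NumDict Rational
numRational = NumDict (+) (*)

-- Source:   square :: Num a => a -> a ; square x = times x x
-- Resolved: the unsatisfied predicate Num a becomes an explicit first
-- parameter, and the method use `times` is applied to it.
square :: NumDict a -> a -> a
square dict x = dictTimes dict x x

-- sumSquares has the same predicate, so it passes its own dictionary
-- parameter on to square (the `lookup p dicts` case in lookupDictionary).
sumSquares :: NumDict a -> a -> a -> a
sumSquares dict x y = dictPlus dict (square dict x) (square dict y)

-- At a call site where the type is known, the instance dictionary is
-- supplied directly (the `byInst` case).
sumSquaresInteger :: Integer -> Integer -> Integer
sumSquaresInteger = sumSquares numInteger
```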


================================================
FILE: src/Duet/Setup.hs
================================================
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}

-- | Shared application code between the command-line and web interfaces.

module Duet.Setup where

import           Control.Monad
import           Control.Monad.Catch
import           Control.Monad.Supply
import           Data.Map.Strict (Map)
import           Duet.Context
import           Duet.Infer
import           Duet.Renamer
import           Duet.Supply
import           Duet.Types

--------------------------------------------------------------------------------
-- Setting the context

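-- | Set up the builtin environment: special types, the builtin classes
-- (Num, Neg, Fractional, Monoid, Slice) with their instances, and the
-- type signatures of builtin constructors, class methods and primops.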
setupEnv
  :: (MonadThrow m, MonadSupply Int m)
  => Map Name (Class Type Name ())
  -> [SpecialTypes Name -> m (DataType Type Name)]
  -> m (Builtins Type Name ())
setupEnv env typeMakers = do
  theArrow <- supplyTypeName "(->)"
  theChar <- supplyTypeName "Char"
  theString <- supplyTypeName "String"
  theInteger <- supplyTypeName "Integer"
  theRational <- supplyTypeName "Rational"
  (true, false, boolDataType) <-
    do name <- supplyTypeName "Bool"
       true <- supplyConstructorName "True"
       false <- supplyConstructorName "False"
       pure
         ( true
         , false
         , DataType
             name
             []
             [DataTypeConstructor true [], DataTypeConstructor false []])
  let function =
        (TypeConstructor
           theArrow
           (FunctionKind StarKind (FunctionKind StarKind StarKind)))
  let specialTypes =
        (SpecialTypes
           { specialTypesBool = boolDataType
           , specialTypesChar = TypeConstructor theChar StarKind
           , specialTypesString = TypeConstructor theString StarKind
           , specialTypesFunction = function
           , specialTypesInteger = TypeConstructor theInteger StarKind
           , specialTypesRational = TypeConstructor theRational StarKind
           })
  (numClass, plus, times) <- makeNumClass function
  (negClass, subtract') <- makeNegClass function
  (fracClass, divide) <- makeFracClass function
  monoidClass <- makeMonoidClass function
  sliceClass <- makeSliceClass (specialTypesInteger specialTypes) function
  boolSigs <- dataTypeSignatures specialTypes boolDataType
  typesSigs <-
    fmap
      concat
      (mapM ($ specialTypes) typeMakers >>=
       mapM (dataTypeSignatures specialTypes))
  classSigs <-
    fmap
      concat
      (mapM classSignatures [numClass, negClass, fracClass, monoidClass, sliceClass])
  primopSigs <- makePrimOps specialTypes
  let signatures = boolSigs <> classSigs <> primopSigs <> typesSigs
      specialSigs =
        SpecialSigs
          { specialSigsTrue = true
          , specialSigsFalse = false
          , specialSigsPlus = plus
          , specialSigsSubtract = subtract'
          , specialSigsTimes = times
          , specialSigsDivide = divide
          }
      specials = Specials specialSigs specialTypes
  stringSlice <-
    makeInst
      specials
      (IsIn
         (className sliceClass)
         [ConstructorType (specialTypesString specialTypes)])
      [ ( "take"
        , ( ()
          , Alternative
              ()
              []
              (VariableExpression () (PrimopName PrimopStringTake))))
      , ( "drop"
        , ( ()
          , Alternative
              ()
              []
              (VariableExpression () (PrimopName PrimopStringDrop))))
      ]
  stringMonoid <-
    makeInst
      specials
      (IsIn
         (className monoidClass)
         [ConstructorType (specialTypesString specialTypes)])
      [ ( "append"
        , ( ()
          , Alternative
              ()
              []
              (VariableExpression () (PrimopName PrimopStringAppend))))
      , ( "empty"
        , ((), Alternative () [] (LiteralExpression () (StringLiteral ""))))
      ]
  numInt <-
    makeInst
      specials
      (IsIn
         (className numClass)
         [ConstructorType (specialTypesInteger specialTypes)])
      [ ( "times"
        , ( ()
          , Alternative
              ()
              []
              (VariableExpression () (PrimopName PrimopIntegerTimes))))
      , ( "plus"
        , ( ()
          , Alternative
              ()
              []
              (VariableExpression () (PrimopName PrimopIntegerPlus))))
      ]
  negInt <-
    makeInst
      specials
      (IsIn
         (className negClass)
         [ConstructorType (specialTypesInteger specialTypes)])
      [ ( "subtract"
        , ( ()
          , Alternative
              ()
              []
              (VariableExpression () (PrimopName PrimopIntegerSubtract))))
      ]
  numRational <-
    makeInst
      specials
      (IsIn
         (className numClass)
         [ConstructorType (specialTypesRational specialTypes)])
      [ ( "times"
        , ( ()
          , Alternative
              ()
              []
              (VariableExpression () (PrimopName PrimopRationalTimes))))
      , ( "plus"
        , ( ()
          , Alternative
              ()
              []
              (VariableExpression () (PrimopName PrimopRationalPlus))))
      ]
  negRational <-
    makeInst
      specials
      (IsIn
         (className negClass)
         [ConstructorType (specialTypesRational specialTypes)])
      [ ( "subtract"
        , ( ()
          , Alternative
              ()
              []
              (VariableExpression () (PrimopName PrimopRationalSubtract))))
      ]
  fracRational <-
    makeInst
      specials
      (IsIn
         (className fracClass)
         [ConstructorType (specialTypesRational specialTypes)])
      [ ( "divide"
        , ( ()
          , Alternative
              ()
              []
              (VariableExpression () (PrimopName PrimopRationalDivide))))
      ]
  env' <-
    let update =
          addClass numClass >=>
          addClass negClass >=>
          addClass fracClass >=>
          addClass monoidClass >=>
          addClass sliceClass >=>
          addInstance numInt >=>
          addInstance negInt >=>
          addInstance stringMonoid >=>
          addInstance stringSlice >=>
          addInstance fracRational >=>
          addInstance negRational >=> addInstance numRational
     in update env
  pure
    Builtins
      { builtinsSpecialSigs = specialSigs
      , builtinsSpecialTypes = specialTypes
      , builtinsSignatures = signatures
      , builtinsTypeClasses = env'
      }

--------------------------------------------------------------------------------
-- Builtin classes and primops

makePrimOps
  :: (MonadSupply Int m)
  => SpecialTypes Name -> m [TypeSignature Type Name Name]
makePrimOps SpecialTypes {..} = do
  let sigs =
        map
          ((\case
              PrimopIntegerPlus ->
                TypeSignature
                  (PrimopName PrimopIntegerPlus)
                  (toScheme (integer --> integer --> integer))
              PrimopIntegerSubtract ->
                TypeSignature
                  (PrimopName PrimopIntegerSubtract)
                  (toScheme (integer --> integer --> integer))
              PrimopIntegerTimes ->
                TypeSignature
                  (PrimopName PrimopIntegerTimes)
                  (toScheme (integer --> integer --> integer))
              PrimopRationalDivide ->
                TypeSignature
                  (PrimopName PrimopRationalDivide)
                  (toScheme (rational --> rational --> rational))
              PrimopRationalPlus ->
                TypeSignature
                  (PrimopName PrimopRationalPlus)
                  (toScheme (rational --> rational --> rational))
              PrimopRationalSubtract ->
                TypeSignature
                  (PrimopName PrimopRationalSubtract)
                  (toScheme (rational --> rational --> rational))
              PrimopRationalTimes ->
                TypeSignature
                  (PrimopName PrimopRationalTimes)
                  (toScheme (rational --> rational --> rational))
              PrimopStringAppend ->
                TypeSignature
                  (PrimopName PrimopStringAppend)
                  (toScheme (string --> string --> string))
              PrimopStringTake ->
                TypeSignature
                  (PrimopName PrimopStringTake)
                  (toScheme (integer --> string --> string))
              PrimopStringDrop ->
                TypeSignature
                  (PrimopName PrimopStringDrop)
                  (toScheme (integer --> string --> string))))
          [minBound .. maxBound]
  pure sigs
  where
    integer = ConstructorType specialTypesInteger
    rational = ConstructorType specialTypesRational
    string = ConstructorType specialTypesString
    infixr 1 -->
    (-->) :: Type Name -> Type Name -> Type Name
    a --> b =
      ApplicationType
        (ApplicationType (ConstructorType specialTypesFunction) a)
        b

makeNumClass :: MonadSupply Int m => TypeConstructor Name -> m (Class Type Name l, Name, Name)
makeNumClass function = do
  a <- fmap (\n -> TypeVariable n StarKind) (supplyTypeName "a")
  let a' = VariableType a
  plus <- supplyMethodName "plus"
  times <- supplyMethodName "times"
  cls <-
    makeClass
      "Num"
      [a]
      [ (plus, Forall [a] (Qualified [] (a' --> a' --> a')))
      , (times, Forall [a] (Qualified [] (a' --> a' --> a')))
      ]
  pure (cls, plus, times)
  where
    infixr 1 -->
    (-->) :: Type Name -> Type Name -> Type Name
    a --> b = ApplicationType (ApplicationType (ConstructorType function) a) b

makeNegClass :: MonadSupply Int m => TypeConstructor Name -> m (Class Type Name l, Name)
makeNegClass function = do
  a <- fmap (\n -> TypeVariable n StarKind) (supplyTypeName "a")
  let a' = VariableType a
  negate' <- supplyMethodName "negate"
  subtract' <- supplyMethodName "subtract"
  abs' <- supplyMethodName "abs"
  cls <-
    makeClass
      "Neg"
      [a]
      [ (negate', Forall [a] (Qualified [] (a' --> a')))
      , (subtract', Forall [a] (Qualified [] (a' --> a' --> a')))
      , (abs', Forall [a] (Qualified [] (a' --> a')))
      ]
  pure (cls, subtract')
  where
    infixr 1 -->
    (-->) :: Type Name -> Type Name -> Type Name
    a --> b = ApplicationType (ApplicationType (ConstructorType function) a) b

makeFracClass :: MonadSupply Int m => TypeConstructor Name -> m (Class Type Name l, Name)
makeFracClass function = do
  a <- fmap (\n -> TypeVariable n StarKind) (supplyTypeName "a")
  let a' = VariableType a
  divide <- supplyMethodName "divide"
  recip' <- supplyMethodName "recip"
  cls <-
    makeClass
      "Fractional"
      [a]
      [ (divide, Forall [a] (Qualified [] (a' --> a' --> a')))
      , (recip', Forall [a] (Qualified [] (a' --> a')))
      ]
  pure (cls, divide)
  where
    infixr 1 -->
    (-->) :: Type Name -> Type Name -> Type Name
    a --> b = ApplicationType (ApplicationType (ConstructorType function) a) b

makeMonoidClass :: MonadSupply Int m => TypeConstructor Name -> m (Class Type Name l)
makeMonoidClass function = do
  a <- fmap (\n -> TypeVariable n StarKind) (supplyTypeName "a")
  let a' = VariableType a
  append <- supplyMethodName "append"
  empty <- supplyMethodName "empty"
  cls <-
    makeClass
      "Monoid"
      [a]
      [ (append, Forall [a] (Qualified [] (a' --> a' --> a')))
      , (empty, Forall [a] (Qualified [] (a')))
      ]
  pure cls
  where
    infixr 1 -->
    (-->) :: Type Name -> Type Name -> Type Name
    a --> b = ApplicationType (ApplicationType (ConstructorType function) a) b

makeSliceClass :: MonadSupply Int m => TypeConstructor Name -> TypeConstructor Name -> m (Class Type Name l)
makeSliceClass integer' function = do
  a <- fmap (\n -> TypeVariable n StarKind) (supplyTypeName "a")
  let a' = VariableType a
  drop' <- supplyMethodName "drop"
  take' <- supplyMethodName "take"
  cls <-
    makeClass
      "Slice"
      [a]
      [ (drop', Forall [a] (Qualified [] (ConstructorType integer' --> (a' --> a'))))
      , (take', Forall [a] (Qualified [] (ConstructorType integer' --> (a' --> a'))))
      ]
  pure cls
  where
    infixr 1 -->
    (-->) :: Type Name -> Type Name -> Type Name
    a --> b = ApplicationType (ApplicationType (ConstructorType function) a) b
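`makePrimOps` builds one `TypeSignature` per constructor of the primop enumeration by mapping over `[minBound .. maxBound]`. Every `where` clause in this module also declares `infixr 1 -->`, so `integer --> string --> string` reads as `integer --> (string --> string)`. Without that fixity declaration a user-defined operator defaults to `infixl 9`, and the same expression would parse as `(integer --> string) --> string`. The sketch below reproduces both ideas with a simplified type AST. `Ty`, `Op`, `opType`, `signatures`, `render` and `atom` are hypothetical names, and only four primops are modelled.

```haskell
-- Simplified type AST (hypothetical; Duet's Type also has variables and kinds).
data Ty = Con String | App Ty Ty
  deriving (Eq, Show)

-- Function arrow as a right-associative operator, as in Duet.Setup.
infixr 1 -->
(-->) :: Ty -> Ty -> Ty
a --> b = App (App (Con "(->)") a) b

-- A cut-down primop enumeration; deriving Enum and Bounded is what lets
-- [minBound .. maxBound] list every constructor.
data Op = IntegerPlus | IntegerTimes | StringAppend | StringTake
  deriving (Eq, Show, Enum, Bounded)

opType :: Op -> Ty
opType op =
  case op of
    IntegerPlus -> integer --> integer --> integer
    IntegerTimes -> integer --> integer --> integer
    StringAppend -> string --> string --> string
    StringTake -> integer --> string --> string
  where
    integer = Con "Integer"
    string = Con "String"

-- One signature per primop, generated exhaustively.
signatures :: [(Op, Ty)]
signatures = [(op, opType op) | op <- [minBound .. maxBound]]

-- Pretty-printer that parenthesises only non-atomic arguments.
render :: Ty -> String
render (App (App (Con "(->)") a) b) = atom a ++ " -> " ++ render b
render (App f x) = render f ++ " " ++ atom x
render (Con n) = n

atom :: Ty -> String
atom t@(Con _) = render t
atom t = "(" ++ render t ++ ")"
```

Writing the signatures as a `case` over an `Enum`/`Bounded` type has a practical benefit. With `-Wincomplete-patterns` enabled, adding a new primop constructor triggers a warning until its signature is written.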


================================================
FILE: src/Duet/Simple.hs
================================================
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}

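-- | Convenience pipeline that renames, type-checks and resolves a program,
-- and runs the substitution stepper on it.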

module Duet.Simple where

import Control.Monad
import Control.Monad.Catch
import Control.Monad.Supply
import Control.Monad.Writer
import Duet.Context
import Duet.Infer
import Duet.Printer
import Duet.Renamer
import Duet.Resolver
import Duet.Setup
import Duet.Stepper
import Duet.Types

-- | Create a context of all renamed, checked and resolved code.
createContext
  :: (MonadSupply Int m, MonadCatch m)
  => [Decl UnkindedType Identifier Location]
  -> m ([BindGroup Type Name (TypeSignature Type Name Location)], Context Type Name Location)
createContext decls = do
  do builtins <-
       setupEnv mempty [] >>=
       traverse
         (const
            (pure
               (Location
                  { locationStartLine = 0
                  , locationStartColumn = 0
                  , locationEndLine = 0
                  , locationEndColumn = 0
                  })))
     let specials = builtinsSpecials builtins
     catch
       (do (typeClasses, signatures, renamedBindings, scope, dataTypes) <-
             renameEverything decls specials builtins
           -- Type class definition
           addedTypeClasses <- addClasses builtins typeClasses
           -- Type checking
           (bindGroups, typeCheckedClasses) <-
             typeCheckModule
               addedTypeClasses
               signatures
               (builtinsSpecialTypes builtins)
               renamedBindings
           -- Type class resolution
           resolvedTypeClasses <-
             resolveTypeClasses
               typeCheckedClasses
               (builtinsSpecialTypes builtins)
           resolvedBindGroups <-
             mapM
               (resolveBindGroup
                  resolvedTypeClasses
                  (builtinsSpecialTypes builtins))
               bindGroups
           -- Create a context of everything
           let ctx =
                 Context
                   { contextSpecialSigs = builtinsSpecialSigs builtins
                   , contextSpecialTypes = builtinsSpecialTypes builtins
                   , contextSignatures = signatures
                   , contextScope = scope
                   , contextTypeClasses = resolvedTypeClasses
                   , contextDataTypes = dataTypes
                   }
           pure (resolvedBindGroups, ctx))
       (throwM . ContextException (builtinsSpecialTypes builtins))

-- | Run the substitution model on the code.
runStepper
  :: forall m. (MonadWriter [Expression Type Name ()] m, MonadSupply Int m, MonadThrow m)
  => Int
  -> Context Type Name Location
  -> [BindGroup Type Name Location]
  -> String
  -> m ()
runStepper maxSteps ctx bindGroups' i = do
  e0 <- lookupNameByString i bindGroups'
  loop 1 "" e0
  where
    loop ::
         Int
      -> String
      -> Expression Type Name Location
      -> m ()
    loop count lastString e = do
      e' <- expandSeq1 ctx bindGroups' e
      let string = printExpression defaultPrint e
      when (string /= lastString) (tell [fmap (const ()) e])
      if (fmap (const ()) e' /= fmap (const ()) e) && count < maxSteps
        then do
          newE <-
            renameExpression
              (contextSpecials ctx)
              (contextScope ctx)
              (contextDataTypes ctx)
              e'
          loop (count + 1) string newE
        else pure ()
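`runStepper` repeatedly calls `expandSeq1` and records each intermediate expression with `tell`. It stops once a step leaves the expression unchanged or `maxSteps` is reached. The sketch below shows the same fixed-point loop over a toy arithmetic language. It uses the left-operand-first strategy that `expandWhnf` applies to infix primop applications: step the left operand until it is a literal, then the right one, then compute. `Expr`, `step`, `binop` and `runSteps` are hypothetical. Unlike `runStepper`, the sketch returns a list instead of using `MonadWriter`, does not skip states that print identically, and does not re-run the renamer between steps.

```haskell
-- Toy expression language (hypothetical; stands in for Duet's Expression).
data Expr = Lit Integer | Add Expr Expr | Mul Expr Expr
  deriving (Eq, Show)

-- Perform a single reduction step. A fully evaluated literal steps to itself,
-- which is how the driver detects that evaluation has finished.
step :: Expr -> Expr
step (Lit n) = Lit n
step (Add x y) = binop Add (+) x y
step (Mul x y) = binop Mul (*) x y

-- Mirror of expandWhnf's infix case: compute when both operands are
-- literals, otherwise step the right operand if the left one is already a
-- literal, otherwise step the left operand.
binop
  :: (Expr -> Expr -> Expr)
  -> (Integer -> Integer -> Integer)
  -> Expr -> Expr -> Expr
binop _ f (Lit a) (Lit b) = Lit (f a b)
binop con _ (Lit a) y = con (Lit a) (step y)
binop con _ x y = con (step x) y

-- Same shape as runStepper's loop: record each state, and continue only
-- while a step changes the expression and fewer than maxSteps states have
-- been recorded.
runSteps :: Int -> Expr -> [Expr]
runSteps maxSteps = go 1
  where
    go count e
      | e' /= e && count < maxSteps = e : go (count + 1) e'
      | otherwise = [e]
      where
        e' = step e
```

For example, `runSteps 100 (Add (Mul (Lit 2) (Lit 3)) (Lit 4))` yields the three states `2 * 3 + 4`, `6 + 4` and `10`. With `maxSteps` set to 2, it stops after recording the first two.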


================================================
FILE: src/Duet/Stepper.hs
================================================
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Strict #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE LambdaCase #-}

-- | The substitution stepper.

module Duet.Stepper
  ( expandSeq1
  , fargs
  , lookupNameByString
  ) where

import           Control.Applicative
import           Control.Monad.Catch
import           Control.Monad.State
import           Data.List
import           Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import           Data.Maybe
import           Duet.Types

--------------------------------------------------------------------------------
-- Expansion

expandSeq1
  :: (MonadThrow m)
  => Context Type Name (Location)
  -> [BindGroup Type Name (Location)]
  -> Expression Type Name (Location)
  -> m (Expression Type Name (Location))
expandSeq1 (Context { contextTypeClasses = typeClassEnv
                    , contextSpecialSigs = specialSigs
                    , contextSignatures = signatures
                    }) b e = evalStateT (go e) False
  where
    go =
      \case
        e0
          -- If we're looking at a constructor, then force the args.
          | (ce@(ConstructorExpression l _), args) <- fargs e0 -> do
            args' <- mapM go args
            pure (foldl (ApplicationExpression l) ce args')
          -- If we're looking at a constant (hole), then force the args.
          | (ce@(ConstantExpression l _), args) <- fargs e0 -> do
            args' <- mapM go args
            pure (foldl (ApplicationExpression l) ce args')
          -- We're looking at a general expression: check whether a step
          -- has already happened in this pass. If so, return the
          -- expression unchanged.
          | otherwise -> do
            alreadyExpanded <- get
            if alreadyExpanded
              then pure e0
              else do
                -- Nothing has been expanded yet in this pass, so take a
                -- single expansion step towards weak head normal form (WHNF).
                e' <- lift (expandWhnf typeClassEnv specialSigs signatures e0 b)
                -- If the expansion did actually produce a new AST
                -- then count that as an expansion.
                put (e' /= e0)
                pure e'

expandWhnf
  :: MonadThrow m
  => Map Name (Class Type Name (TypeSignature Type Name Location))
  -> SpecialSigs Name
  -> [TypeSignature Type Name Name]
  -> Expression Type Name (Location)
  -> [BindGroup Type Name (Location)]
  -> m (Expression Type Name (Location))
expandWhnf typeClassEnv specialSigs signatures e b = go e
  where
    go x =
      case x of
        -- Parens aren't an expansion step, just a grouping.
        ParensExpression _ e -> go e
        VariableExpression _ i -> do
          case find ((== i) . typeSignatureA) signatures of
            Nothing -> do
              e' <- lookupName i b
              pure e'
            Just {} -> pure x
        LiteralExpression {} -> return x
        ConstructorExpression {} -> return x
        ConstantExpression {} -> return x
        ApplicationExpression l (ApplicationExpression l1 op@(VariableExpression _ (PrimopName primop)) x) y ->
          case x of
            LiteralExpression _ (StringLiteral sx) ->
              case y of
                LiteralExpression _ (StringLiteral sy) ->
                  case primop of
                    PrimopStringAppend ->
                      pure (LiteralExpression l (StringLiteral (sx <> sy)))
                    _ -> error "Runtime type error that should not occur"
                _ -> do
                  y' <- go y
                  pure
                    (ApplicationExpression l (ApplicationExpression l1 op x) y')
            LiteralExpression _ (IntegerLiteral n) ->
              case y of
                LiteralExpression _ (StringLiteral sy) ->
                  case primop of
                    PrimopStringTake ->
                      pure (LiteralExpression l (StringLiteral (genericTake n sy)))
                    PrimopStringDrop ->
                      pure (LiteralExpression l (StringLiteral (genericDrop n sy)))
                    _ -> error "Runtime type error that should not occur"
                _ -> do
                  y' <- go y
                  pure
                    (ApplicationExpression l (ApplicationExpression l1 op x) y')
            _ -> do
              x' <- go x
              pure (ApplicationExpression l (ApplicationExpression l1 op x') y)
        ApplicationExpression l func arg ->
          case func of
            LambdaExpression l0 (Alternative l' params body) ->
              case params of
                (VariablePattern _ param:params') ->
                  let body' = substitute param arg body
                  in case params' of
                       [] -> pure body'
                       _ ->
                         pure
                           (LambdaExpression l0 (Alternative l' params' body'))
                _ -> error "Unsupported lambda."
            VariableExpression _ (MethodName _ methodName) ->
              case arg of
                VariableExpression _ dictName@DictName {} ->
                  case find
                         ((== dictName) . dictionaryName)
                         (concatMap
                            (map instanceDictionary . classInstances)
                            (M.elems typeClassEnv)) of
                    Nothing -> throwM (CouldntFindMethodDict dictName)
                    Just dict ->
                      case M.lookup
                             methodName
                             (M.mapKeys
                                (\(MethodName _ s) -> s)
                                (dictionaryMethods dict)) of
                        Nothing ->
                          error
                            ("Missing method " ++
                             show methodName ++ " in dictionary: " ++ show dict)
                        Just (_, Alternative _ _ e) -> pure (fmap typeSignatureA e)
                _ -> error "Unsupported variable expression."
            _ -> do
              func' <- go func
              pure (ApplicationExpression l func' arg)
        orig@(InfixExpression l x op@(_s, VariableExpression _ (PrimopName primop)) y) ->
          case x of
            LiteralExpression _ x' ->
              case y of
                LiteralExpression _ y' ->
                  case (x', y') of
                    (IntegerLiteral i1, IntegerLiteral i2) ->
                      pure
                        (LiteralExpression
                           l
                           (case primop of
                              PrimopIntegerPlus -> IntegerLiteral (i1 + i2)
                              PrimopIntegerTimes -> IntegerLiteral (i1 * i2)
                              PrimopIntegerSubtract -> IntegerLiteral (i1 - i2)
                              _ -> error "Unexpected operation for integer literals."))
                    (RationalLiteral i1, RationalLiteral i2) ->
                      pure
                        (LiteralExpression
                           l
                           (case primop of
                              PrimopRationalPlus -> RationalLiteral (i1 + i2)
                              PrimopRationalTimes -> RationalLiteral (i1 * i2)
                              PrimopRationalSubtract ->
                                RationalLiteral (i1 - i2)
                              PrimopRationalDivide -> RationalLiteral (i1 / i2)
                              _ -> error "Unexpected operation for rational literals."))
                    _ -> pure orig
                _ -> do
                  y' <- go y
                  pure (InfixExpression l x op y')
            _ -> do
              x' <- go x
              pure (InfixExpression l x' op y)
        InfixExpression l x (s, op) y -> do
          op' <- go op
          pure (InfixExpression l x (s, op') y)
        IfExpression l pr th el ->
          case pr of
            ConstructorExpression _ n
              | n == specialSigsTrue specialSigs -> pure th
              | n == specialSigsFalse specialSigs -> pure el
            _ -> IfExpression l <$> go pr <*> pure th <*> pure el
        LetExpression {} -> return x
        LambdaExpression {} -> return x
        CaseExpression l e0 alts ->
          let matches =
                map
                  (\ca -> (match e0 (caseAltPattern ca), caseAltExpression ca))
                  alts
          in case listToMaybe
                    (mapMaybe
                       (\(r, e) -> do
                          case r of
                            OK v -> pure (v, e)
                            Fail -> Nothing)
                       matches) of
               Just (Success subs, expr) ->
                 return
                   (foldr
                      (\(name, that) expr' -> substitute name that expr')
                      expr
                      subs)
               Just (NeedsMoreEval is, _) -> do
                 e' <- expandAt typeClassEnv is specialSigs signatures e0 b
                 pure (CaseExpression l e' alts)
               Nothing -> error ("Incomplete pattern match... " ++ show matches)
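Alternative selection above discards the alternatives whose match is `Fail` and takes the first one left. A `NeedsMoreEval` result is kept, so an undecided earlier alternative blocks every later one, which preserves top-to-bottom matching. Below is a minimal sketch with simplified stand-in types. Bindings are `(String, Int)` pairs, and `Match`, `Result` and `selectAlt` here are hypothetical, not Duet's definitions:

```haskell
import Data.Maybe (listToMaybe, mapMaybe)

-- Hypothetical, simplified stand-ins for the stepper's Match and Result types.
data Match = Success [(String, Int)] | NeedsMoreEval [Int]
  deriving (Show, Eq)

data Result = OK Match | Fail
  deriving (Show, Eq)

-- | Pick the first alternative whose pattern did not definitely fail,
-- mirroring listToMaybe (mapMaybe ...) in the stepper. A NeedsMoreEval
-- result is returned rather than skipped, so later alternatives wait.
selectAlt :: [(Result, e)] -> Maybe (Match, e)
selectAlt = listToMaybe . mapMaybe viable
  where
    viable (OK m, e) = Just (m, e)
    viable (Fail, _) = Nothing
```

Because of laziness, alternatives after the selected one are never matched at all.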

expandAt
  :: MonadThrow m
  => Map Name (Class Type Name (TypeSignature Type Name Location))
  -> [Int]
  -> SpecialSigs Name
  -> [TypeSignature Type Name Name]
  -> Expression Type Name (Location)
  -> [BindGroup Type Name (Location)]
  -> m (Expression Type Name (Location))
expandAt typeClassEnv is specialSigs signatures e0 b = go [0] e0
  where
    go js e =
      if is == js
        then expandWhnf typeClassEnv specialSigs signatures e b
        else case e of
               _
                 | (ce@(ConstructorExpression l _), args) <- fargs e -> do
                   args' <-
                     sequence
                       (zipWith (\i arg -> go (js ++ [i]) arg) [0 ..] args)
                   pure (foldl (ApplicationExpression l) ce args')
                 | otherwise -> pure e
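`match` reports where more evaluation is needed as a path. `[0]` is the scrutinee itself, and each further index selects an argument of a constructor application. `expandAt` follows such a path down through constructor applications only and evaluates the subterm it addresses. The sketch below reproduces that traversal on a hypothetical mini-language. `Expr`, `atPath` and `step` are illustrative names, and `step` stands in for `expandWhnf`:

```haskell
-- Hypothetical mini-language: constructor applications, literals and addition.
data Expr = Con String [Expr] | Lit Int | Add Expr Expr
  deriving (Show, Eq)

-- | Rewrite the subterm addressed by a path. The root is [0]; appending i
-- descends into the i-th argument of a constructor application. Like
-- expandAt, it only descends through constructor applications and leaves
-- every other subterm unchanged.
atPath :: (Expr -> Expr) -> [Int] -> Expr -> Expr
atPath f target = go [0]
  where
    go here e
      | here == target = f e
      | Con c args <- e = Con c (zipWith (\i a -> go (here ++ [i]) a) [0 ..] args)
      | otherwise = e

-- | One evaluation step, standing in for expandWhnf.
step :: Expr -> Expr
step (Add (Lit a) (Lit b)) = Lit (a + b)
step e = e
```

For example, `atPath step [0, 0]` applied to `Con "Cons" [Add (Lit 1) (Lit 1), Con "Nil" []]` rewrites only the first argument, producing `Con "Cons" [Lit 2, Con "Nil" []]`.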

--------------------------------------------------------------------------------
-- Pattern matching

match
  :: (Eq i)
  => Expression Type i l -> Pattern Type i l -> Result (Match Type i l)
match = go [0]
  where
    go is val pat =
      case pat of
        BangPattern p
          | isWhnf val -> go is val p
          | otherwise -> OK (NeedsMoreEval is)
        AsPattern _l ident pat ->
          case go is val pat of
            OK (Success binds) -> OK (Success ((ident, val) : binds))
            res -> res
        WildcardPattern _ _ -> OK (Success [])
        VariablePattern _ i -> OK (Success [(i, val)])
        LiteralPattern _ l ->
          case val of
            LiteralExpression _ l'
              | l' == l -> OK (Success [])
              | otherwise -> Fail
            _ -> OK (NeedsMoreEval is)
        ConstructorPattern _ i pats
          | (constructor@ConstructorExpression {}, args) <- fargs val ->
            if fmap (const ()) constructor == ConstructorExpression () i
              then if length args == length pats
                     then foldl
                            (<>)
                            (OK (Success []))
                            (zipWith
                               (\j (arg, p) -> go (is ++ [j]) arg p)
                               [0 ..]
                               (zip args pats))
                     else Fail
              else Fail
          | otherwise -> OK (NeedsMoreEval is)
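`match` is a three-valued matcher. It either succeeds with bindings, fails definitely, or returns the path of a subterm that must be evaluated before it can decide. Below is a minimal sketch of the same idea on a hypothetical mini-language. The rule in `combine` for merging argument results is an assumption made for this sketch: any `Fail` fails the whole match, otherwise the first `NeedsMoreEval` wins, otherwise bindings are concatenated. It is not taken from Duet's `Semigroup` instances.

```haskell
-- Hypothetical mini-language standing in for Duet's Expression and Pattern.
data Expr = Con String [Expr] | Lit Int | Var String
  deriving (Show, Eq)

data Pat = PCon String [Pat] | PLit Int | PVar String | PWild
  deriving Show

data Match = Success [(String, Expr)] | NeedsMoreEval [Int]
  deriving (Show, Eq)

data Result = OK Match | Fail
  deriving (Show, Eq)

-- Assumed combining rule (see lead-in), applied left to right.
combine :: Result -> Result -> Result
combine Fail _ = Fail
combine _ Fail = Fail
combine (OK (NeedsMoreEval p)) _ = OK (NeedsMoreEval p)
combine _ (OK (NeedsMoreEval p)) = OK (NeedsMoreEval p)
combine (OK (Success a)) (OK (Success b)) = OK (Success (a ++ b))

-- | Match a value against a pattern. The scrutinee is at path [0]; the j-th
-- constructor argument of the subterm at path p is at p ++ [j].
match :: Expr -> Pat -> Result
match = go [0]
  where
    go _ _ PWild = OK (Success [])
    go _ v (PVar x) = OK (Success [(x, v)])
    go path v (PLit n) =
      case v of
        Lit m
          | m == n -> OK (Success [])
          | otherwise -> Fail
        _ -> OK (NeedsMoreEval path) -- not a literal yet: evaluate here first
    go path v (PCon c ps) =
      case v of
        Con c' args
          | c' /= c || length args /= length ps -> Fail
          | otherwise ->
              foldl combine (OK (Success []))
                (zipWith3 (\j a p -> go (path ++ [j]) a p) [0 ..] args ps)
        _ -> OK (NeedsMoreEval path) -- not a constructor yet: evaluate here first
```

For instance, matching `Con "Cons" [Var "y", Con "Nil" []]` against `PCon "Cons" [PLit 2, PWild]` returns `OK (NeedsMoreEval [0, 0])`, which points at the unevaluated first argument.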

isWhnf :: Expression Type i l -> Bool
isWhnf =
  \case
    VariableExpression {} -> False
    ConstructorExpression {} -> True
    ConstantExpression {} -> True
    LiteralExpression {} -> True
    ApplicationExpression {} -> False
    InfixExpression {} -> False
    LetExpression {} -> False
    LambdaExpression {} -> True
    IfExpression {} -> False
    CaseExpression {} -> False
    ParensExpression {} -> False
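The `isWhnf` check above returns `False` for every `ApplicationExpression`. As a result, a constructor applied to arguments (for example `Just 1`) is not treated as being in weak head normal form. For comparison, here is a sketch of the conventional definition, in which a constructor application counts as WHNF. It uses a hypothetical mini-language (`Expr`, `spineHead`) rather than Duet's types. The source does not say whether Duet's stricter check is intentional.

```haskell
-- Hypothetical mini-language for illustrating weak head normal form.
data Expr = Var String | Con String | Lit Int | App Expr Expr | Lam String Expr
  deriving Show

-- | Find the head of an application spine and how many arguments it has.
spineHead :: Expr -> (Expr, Int)
spineHead (App f _) = let (h, n) = spineHead f in (h, n + 1)
spineHead e = (e, 0)

-- | Textbook weak head normal form for this mini-language.
isWhnf :: Expr -> Bool
isWhnf e =
  case spineHead e of
    (Con _, _) -> True   -- constructor, applied to any number of arguments
    (Lit _, 0) -> True   -- literal
    (Lam _ _, 0) -> True -- unapplied lambda
    _ -> False           -- variables and beta-redexes still need evaluation
```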

--------------------------------------------------------------------------------
-- Expression manipulators

-- | Flatten an application f x y into (f,[x,y]).
fargs :: Expression Type i l -> (Expression Type i l, [(Expression Type i l)])
fargs e = go e []
  where
    go (ApplicationExpression _ f x) args = go f (x : args)
    go f args = (f, args)
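`fargs` flattens the left-nested spine of an application, and `expandAt` rebuilds one with `foldl (ApplicationExpression l)`. Rebuilding the flattened form gives back the original expression. This sketch shows that on a two-constructor stand-in type. `Expr` and `unfargs` are hypothetical names:

```haskell
-- Hypothetical stand-in with just variables and application.
data Expr = Var String | App Expr Expr
  deriving (Show, Eq)

-- | Flatten f x y, i.e. App (App f x) y, into (f, [x, y]).
-- Arguments are accumulated from the outermost application inwards.
fargs :: Expr -> (Expr, [Expr])
fargs e = go e []
  where
    go (App f x) args = go f (x : args)
    go f args = (f, args)

-- | Rebuild a left-nested application, as expandAt does with foldl.
unfargs :: Expr -> [Expr] -> Expr
unfargs = foldl App
```

Arguments that are themselves applications, such as `App (Var "g") (Var "x")` in `f (g x) y`, are left intact. Only the outer spine is flattened.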

--------------------------------------------------------------------------------
-- Substitutions

substitute :: Eq i => i -> Expression Type i l -> Expression Type i l -> Expression Type i l
substitute i arg = go
  where
    go =
      \case
        VariableExpression l i'
          | i == i' -> arg
          | otherwise -> VariableExpression l i'
        x@ConstructorExpression {} -> x
        x@ConstantExpression {} -> x
        ParensExpression _ e -> go e
        ApplicationExpression l f x -> ApplicationExpression l (go f) (go x)
        InfixExpression l x (s, f) y -> InfixExpression l (go x) (s, go f) (go y)
        LetExpression {} -> error "let expressions unsupported."
        CaseExpression l e cases ->
          CaseExpression l (go e) (map (\(CaseAlt l pat e') -> CaseAlt l pat (go e')) cases)
        IfExpression l a b c -> IfExpression l (go a) (go b) (go c)
        x@LiteralExpression {} -> x
        LambdaExpression l (Alternative l' args body) ->
          LambdaExpression l (Alternative l' args (go body))
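The `substitute` above rewrites lambda bodies and case alternatives without checking whether a binder shadows `i`. That is only safe if every binder carries a unique `Name`. The fresh-name functions in `Duet.Supply` suggest the renamer provides this, but that is an assumption, not something this code confirms. Without unique names, substitution has to stop at a shadowing binder, as in this sketch over a hypothetical lambda calculus:

```haskell
-- Hypothetical untyped lambda calculus.
data Expr = Var String | App Expr Expr | Lam String Expr
  deriving (Show, Eq)

-- | Replace free occurrences of x with arg, stopping at a lambda that
-- rebinds x. Still assumes arg's free variables are not captured by any
-- binder, which holds when all binders have unique names.
substitute :: String -> Expr -> Expr -> Expr
substitute x arg = go
  where
    go (Var y)
      | y == x = arg
      | otherwise = Var y
    go (App f a) = App (go f) (go a)
    go (Lam y body)
      | y == x = Lam y body -- x is shadowed here: leave the body alone
      | otherwise = Lam y (go body)
```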

--------------------------------------------------------------------------------
-- Lookups

lookupName
  :: (MonadThrow m)
  => Name
  -> [BindGroup Type Name (Location)]
  -> m (Expression Type Name (Location))
lookupName identifier binds =
  case listToMaybe (mapMaybe findIdent binds) of
    Nothing -> throwM (CouldntFindName identifier)
    Just i -> pure i
  where
    findIdent (BindGroup es is) =
      listToMaybe
        (mapMaybe
           (\case
              ImplicitlyTypedBinding _ (i, _) [Alternative _ [] e]
                | i == identifier -> Just e
              _ -> Nothing)
           (concat is)) <|>
      listToMaybe
        (mapMaybe
           (\case
              ExplicitlyTypedBinding _ (i, _) _ [Alternative _ [] e]
                | i == identifier -> Just e
              _ -> Nothing)
           es)
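`lookupName` searches bind groups in order. Within each group it checks implicitly typed bindings before explicitly typed ones. It only matches bindings with exactly one alternative and no parameters (`[Alternative _ [] e]`), so any binding with parameters or with several alternatives is skipped. Below is a simplified sketch that returns `Maybe` instead of throwing through `MonadThrow`. The `Binding` and `BindGroup` records are hypothetical:

```haskell
import Control.Applicative ((<|>))
import Data.Maybe (listToMaybe, mapMaybe)

-- Hypothetical simplification: a binding is a name, its parameters and a body.
data Binding = Binding
  { bindName :: String
  , bindParams :: [String]
  , bindBody :: String
  }

-- Explicitly typed bindings, then implicitly typed ones grouped in lists.
data BindGroup = BindGroup [Binding] [[Binding]]

lookupName :: String -> [BindGroup] -> Maybe String
lookupName name groups = listToMaybe (mapMaybe findIn groups)
  where
    -- Only parameterless bindings with the requested name are returned.
    pick b
      | bindName b == name, null (bindParams b) = Just (bindBody b)
      | otherwise = Nothing
    -- Implicitly typed bindings are checked first, as in the original.
    findIn (BindGroup explicit implicit) =
      listToMaybe (mapMaybe pick (concat implicit))
        <|> listToMaybe (mapMaybe pick explicit)
```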

lookupNameByString
  :: (MonadThrow m)
  => String
  -> [BindGroup Type Name (Location)]
  -> m (Expression Type Name (Location))
lookupNameByString identifier binds =
  case listToMaybe (mapMaybe findIdent binds) of
    Nothing -> throwM (CouldntFindNameByString identifier)
    Just i -> pure i
  where
    findIdent (BindGroup es is) =
      listToMaybe
        (mapMaybe
           (\case
              ImplicitlyTypedBinding _ (ValueName _ i, _) [Alternative _ [] e]
                | i == identifier -> Just e
              _ -> Nothing)
           (concat is)) <|>
      listToMaybe
        (mapMaybe
           (\case
              ExplicitlyTypedBinding _ (ValueName _ i, _) _ [Alternative _ [] e]
                | i == identifier -> Just e
              _ -> Nothing)
           es)
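`lookupNameByString` differs from `lookupName` only in what it compares. It matches the source string inside a `ValueName` and ignores the unique `Int`, whereas `lookupName` needs the exact `Name`, counter included. This is presumably useful for finding a top-level value such as `main` before its renamed `Name` is known. The stand-in `Name` type below is hypothetical and keeps only two of Duet's constructors:

```haskell
import Data.Maybe (listToMaybe)

-- Hypothetical stand-in keeping two of Duet's Name constructors.
data Name = ValueName Int String | ConstructorName Int String
  deriving (Show, Eq)

-- | Exact lookup: the unique Int must match as well as the string.
lookupByName :: Name -> [(Name, e)] -> Maybe e
lookupByName = lookup

-- | Lookup by source text, ignoring the unique Int. Entries that are not
-- ValueNames fail the comprehension's pattern and are skipped.
lookupByString :: String -> [(Name, e)] -> Maybe e
lookupByString s binds = listToMaybe [e | (ValueName _ s', e) <- binds, s' == s]
```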


================================================
FILE: src/Duet/Supply.hs
================================================
{-# LANGUAGE Strict #-}
{-# LANGUAGE FlexibleContexts #-}
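-- | Fresh-name helpers: each function draws a unique 'Int' from 'MonadSupply'
-- and combines it with an identifier to build a 'Name'.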

module Duet.Supply where

import Control.Monad.Catch
import Control.Monad.Supply
import Duet.Types

supplyValueName :: (MonadSupply Int m, Identifiable i, MonadThrow m) => i -> m Name
supplyValueName s = do
  i <- supply
  Identifier s' <- identifyValue s
  return (ValueName i s')

supplyConstructorName :: (MonadSupply Int m) => Identifier -> m Name
supplyConstructorName (Identifier s) = do
  i <- supply
  return (ConstructorName i s)

supplyDictName :: (MonadSupply Int m) => String -> m Name
supplyDictName s = do
  i <- supply
  return (DictName i s)

supplyDictName' :: (MonadSupply Int m, MonadThrow m) => Identifier -> m Name
supplyDictName' s = do
  i <- supply
  Identifier s' <- identifyValue s
  return (DictName i s')

supplyTypeName :: (MonadSupply Int m) => Identifier -> m Name
supplyTypeName (Identifier s) = do
  i <- supply
  return (TypeName i s)

supplyTypeVariableName :: (MonadSupply Int m) => Identifier -> m Name
supplyTypeVariableName (Identifier s) = do
  i <- supply
  return (TypeName i (s ++ show i))

supplyClassName :: (MonadSupply Int m) => Identifier -> m Name
supplyClassName (Identifier s) = do
  i <- supply
  return (ClassName i s)

supplyMethodName :: (MonadSupply Int m) => Identifier -> m Name
supplyMethodName (Identifier s) = do
  i <- supply
  return (MethodName i s)
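Each `supply*Name` function takes the next `Int` from `MonadSupply` and pairs it with the identifier's text. `supplyTypeVariableName` also appends the number to the text, so printed type variables stay distinct. Below is a base-only sketch. A hand-rolled `Supply` state monad stands in for `Control.Monad.Supply`, and its representation is an assumption, as is the cut-down `Name` type:

```haskell
-- Hypothetical stand-in for MonadSupply Int: a state monad over the next unused Int.
newtype Supply a = Supply {runSupply :: Int -> (a, Int)}

instance Functor Supply where
  fmap f (Supply g) = Supply (\n -> let (a, n') = g n in (f a, n'))

instance Applicative Supply where
  pure a = Supply (\n -> (a, n))
  Supply f <*> Supply g =
    Supply (\n -> let (h, n1) = f n; (a, n2) = g n1 in (h a, n2))

instance Monad Supply where
  Supply g >>= k = Supply (\n -> let (a, n') = g n in runSupply (k a) n')

-- | Return the current counter and advance it.
supply :: Supply Int
supply = Supply (\n -> (n, n + 1))

-- Hypothetical cut-down Name type.
data Name = ValueName Int String | TypeName Int String
  deriving (Show, Eq)

supplyValueName :: String -> Supply Name
supplyValueName s = do
  i <- supply
  return (ValueName i s)

-- | Like Duet's version, the unique Int is also appended to the text.
supplyTypeVariableName :: String -> Supply Name
supplyTypeVariableName s = do
  i <- supply
  return (TypeName i (s ++ show i))
```

Running `mapM supplyValueName ["x", "x"]` from counter `0` gives two names that share the text `"x"` but have different ids, `ValueName 0 "x"` and `ValueName 1 "x"`.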


================================================
FILE: src/Duet/Tokenizer.hs
================================================
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}

-- | Duet syntax tokenizer.

module Duet.Tokenizer where

import           Control.Monad
import           Data.Char
import           Data.List
import           Data.Text (Text)
import qualified Data.Text as T
import           Duet.Printer
import           Duet.Types
import           Text.Parsec hiding (anyToken)
import           Text.Parsec.Text
import           Text.Printf

tokenize :: FilePath -> Text -> Either ParseError [(Token, Location)]
tokenize fp t = parse tokensTokenizer fp t

tokensTokenizer :: Parser [(Token, Location)]
tokensTokenizer =
  manyTill (many space >>= tokenTokenizer) (try (spaces >> eof))
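`tokensTokenizer` repeatedly consumes whitespace and passes it to `tokenTokenizer`, stopping once only whitespace remains. When the consumed whitespace ends in a newline, the next token starts in column 1, and a `NonIndentedNewline` token is emitted before it. The Parsec-free sketch below reproduces that rule for plain words and drops source locations. `tokens`, `Token` and `Word` are hypothetical names:

```haskell
import Data.Char (isSpace)
import Data.List (isSuffixOf)

-- Hypothetical token type: just the newline marker and bare words.
data Token = NonIndentedNewline | Word String
  deriving (Show, Eq)

-- | Split off leading whitespace; stop if nothing else is left (like
-- `spaces >> eof`), otherwise emit a marker when that whitespace ends in
-- '\n' and then read one word.
tokens :: String -> [Token]
tokens input =
  case span isSpace input of
    (_, "") -> []
    (pre, rest)
      | "\n" `isSuffixOf` pre -> NonIndentedNewline : wordThen rest
      | otherwise -> wordThen rest
  where
    wordThen s = let (w, rest') = break isSpace s in Word w : tokens rest'
```

For `"a b\nc\n  d"` this yields `[Word "a", Word "b", NonIndentedNewline, Word "c", Word "d"]`. The indented `d` gets no marker because its preceding whitespace ends in spaces, not a newline.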

tokenTokenizer :: [Char] -> Parser (Token, Location)
tokenTokenizer prespaces =
  choice
    [ if isSuffixOf "\n" prespaces
        then do
          pos <- getPosition
          pure
            ( NonIndentedNewline
            , Location
                (sourceLine pos)
                (sourceColumn pos)
                (sourceLine pos)
                (sourceColumn pos))
        else unexpected "indented newline"
    , atomThenSpace If "if"
    , atomThenSpace Then "then"
    , atomThenSpace ClassToken "class"
    , atomThenSpace InstanceToken "instance"
    , atomThenSpace Where "where"
    , atomThenSpace Data "data"
    , atomThenSpace Else "else"
    , atomThenSpace ForallToken "forall"
    , atomThenSpace Case "case"
    , atomThenSpace Of "of"
    , atom Bang "!"
    , atom Period "."
    , atom Backslash "\\"
About this extraction

This page contains the full source code of the chrisdone/duet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 53 files (236.5 KB), approximately 61.7k tokens. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
