{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}
module TH.Derive.Storable
( makeStorableInst
) where
import Control.Applicative
import Control.Monad
import Data.List (find)
import Data.Maybe (fromMaybe)
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import Prelude
import TH.Derive.Internal
import TH.ReifySimple
import TH.Utilities
-- | Hook 'Storable' into the 'Deriver' class. The proxy argument only
-- selects this instance; its value is ignored and the remaining
-- arguments (instance context and instance head) are handed straight to
-- 'makeStorableInst'.
instance Deriver (Storable a) where
    runDeriver = const makeStorableInst
-- | Build a 'Storable' instance whose context is @preds@ and whose head
-- is @ty@.
--
-- @ty@ is given to 'expectTyCon1' together with @''Storable@; the type it
-- yields (presumably the @T@ in @Storable T@ — see "TH.Utilities") is
-- reified with 'reifyDataTypeSubstituted', and the instance is generated
-- by 'makeStorableImpl' from that data type's constructors.
makeStorableInst :: Cxt -> Type -> Q [Dec]
makeStorableInst preds ty =
    expectTyCon1 ''Storable ty
        >>= reifyDataTypeSubstituted
        >>= makeStorableImpl preds ty . dtCons
-- | Generate a 'Storable' instance, with context @preds@ and head
-- @headTy@, for a data type whose constructors are @cons@.
--
-- Memory layout used by the generated methods:
--
-- * With two or more constructors, the 0-based index of the constructor
--   is stored at offset 0 as a tag. The tag type is the first of 'Word8',
--   'Word16', 'Word32', 'Word64' whose 'maxBound' is at least the number
--   of constructors. With zero or one constructor no tag is stored.
-- * The constructor's fields follow the tag back to back, in declaration
--   order, each occupying @sizeOf@ bytes of its own type. No padding is
--   inserted, which is why 'alignment' is always 1.
--
-- The generated 'sizeOf' is the tag size plus the largest total field
-- size over all constructors (0 when there are no constructors).
-- Peeking a type with no constructors, or reading a tag that matches no
-- constructor, calls 'error'.
makeStorableImpl :: Cxt -> Type -> [DataCon] -> Q [Dec]
makeStorableImpl preds headTy cons = do
    -- Fields are never padded, so no particular alignment is required.
    alignmentMethod <- [| 1 |]
    sizeOfMethod <- sizeExpr
    peekMethod <- peekExpr
    pokeMethod <- pokeExpr
    let methods :: [Dec]
        methods =
            [ FunD (mkName "alignment") [Clause [WildP] (NormalB alignmentMethod) []]
            , FunD (mkName "sizeOf") [Clause [WildP] (NormalB sizeOfMethod) []]
            , FunD (mkName "peek") [Clause [VarP ptrName] (NormalB peekMethod) []]
            , FunD (mkName "poke")
                   [Clause [VarP ptrName, VarP valName] (NormalB pokeMethod) []]
            ]
    return [plainInstanceD preds headTy methods]
  where
    -- Pick the smallest tag type able to number every constructor.
    (tagType, _, tagSize) =
        fromMaybe (error "Too many constructors") $
        find (\(_, maxN, _) -> maxN >= toInteger (length cons)) tagTypes

    -- (tag type name, largest representable tag, tag size in bytes).
    -- Bounds are kept as Integer: converting maxBound :: Word64 to Int
    -- would wrap to -1 and make that entry unselectable.
    -- The unit entry has size 0; it is only chosen for types with at most
    -- one constructor, which never store a tag.
    tagTypes :: [(Name, Integer, Int)]
    tagTypes =
        [ (''(), 1, 0)
        , (''Word8, toInteger (maxBound :: Word8), 1)
        , (''Word16, toInteger (maxBound :: Word16), 2)
        , (''Word32, toInteger (maxBound :: Word32), 4)
        , (''Word64, toInteger (maxBound :: Word64), 8)
        ]

    valName, tagName, ptrName :: Name
    valName = mkName "val"
    tagName = mkName "tag"
    ptrName = mkName "ptr"

    -- Name bound to the ix-th field when pattern matching in poke.
    fName :: Int -> Name
    fName ix = mkName ("f" ++ show ix)

    ptrExpr :: ExpQ
    ptrExpr = varE ptrName

    -- sizeOf: the tag (if any) plus the largest constructor payload. The
    -- tag must be counted because every field is stored after it. The
    -- leading 0 keeps 'maximum' total for types with no constructors.
    sizeExpr :: ExpQ
    sizeExpr =
        [| tagSize + maximum (0 : $(listE (map conSizeExpr cons))) |]

    -- Sum of the sizes of all fields of a single constructor.
    conSizeExpr :: DataCon -> ExpQ
    conSizeExpr (DataCon _ _ _ fields) =
        appE (varE 'sum) (listE [sizeOfExpr ty | (_, ty) <- fields])

    -- peek: no constructors -> error; one constructor -> read its fields
    -- directly; several -> read the tag, then dispatch on it.
    peekExpr :: ExpQ
    peekExpr = case cons of
        [] -> [| error ("Attempting to peek type with no constructors (" ++ $(lift (pprint headTy)) ++ ")") |]
        [con] -> peekCon con
        _ -> doE
            [ bindS (varP tagName) [| peek (castPtr $(ptrExpr)) |]
            , noBindS (caseE (sigE (varE tagName) (conT tagType))
                             (zipWith peekMatch [0 ..] cons ++ [peekErr]))
            ]

    -- Case alternative: tag value ix selects constructor con.
    peekMatch :: Integer -> DataCon -> MatchQ
    peekMatch ix con = match (litP (IntegerL ix)) (normalB (peekCon con)) []

    -- Fallback alternative for tags that match no constructor.
    peekErr :: MatchQ
    peekErr = match wildP (normalB [| error ("Found invalid tag while peeking (" ++ $(lift (pprint headTy)) ++ ")") |]) []

    -- Rebuild one constructor applicatively from its peeked fields.
    peekCon :: DataCon -> ExpQ
    peekCon (DataCon cname _ _ fields) =
        letE (offsetDecls fields) $
        case fields of
            [] -> [| pure $(conE cname) |]
            (_ : rest) ->
                foldl (\acc ix -> [| $(acc) <*> $(peekOffset ix) |])
                      [| $(conE cname) <$> $(peekOffset 0) |]
                      (zipWith const [1 ..] rest)

    peekOffset :: Int -> ExpQ
    peekOffset ix = [| peek (castPtr (plusPtr $(ptrExpr) $(varE (offset ix)))) |]

    -- poke: one case alternative per constructor.
    pokeExpr :: ExpQ
    pokeExpr = caseE (varE valName) (zipWith (curry pokeMatch) [0 ..] cons)

    pokeMatch :: (Int, DataCon) -> MatchQ
    pokeMatch (ixcon, DataCon cname _ _ fields) =
        match (conP cname (map (varP . fName) ixs))
              (normalB (case tagPokes ++ offsetLet ++ fieldPokes of
                           [] -> [|return ()|]
                           stmts -> doE stmts))
              []
      where
        -- Only multi-constructor types store a tag.
        tagPokes :: [StmtQ]
        tagPokes = case cons of
            (_ : _ : _) -> [noBindS [| poke (castPtr $(ptrExpr)) (ixcon :: $(conT tagType)) |]]
            _ -> []
        offsetLet :: [StmtQ]
        offsetLet
            | null ixs = []
            | otherwise = [letS (offsetDecls fields)]
        fieldPokes :: [StmtQ]
        fieldPokes = map (noBindS . pokeField) ixs
        ixs :: [Int]
        ixs = zipWith const [0 ..] fields

    pokeField :: Int -> ExpQ
    pokeField ix = [| poke (castPtr (plusPtr $(ptrExpr)
                                               $(varE (offset ix))))
                           $(varE (fName ix)) |]

    -- Declarations offset0 .. offset(n-1) giving each field's byte offset:
    -- offset0 is the tag size, and each later offset adds the size of the
    -- previous field. The offset past the last field is dropped with
    -- 'init' (never empty: the tag entry is always present) so the
    -- generated code has no unused bindings.
    offsetDecls :: [(a, Type)] -> [DecQ]
    offsetDecls fields =
        init $
        map (\(ix, expr) -> valD (varP (offset ix)) (normalB expr) []) $
        ((0, [| tagSize |]) :) $
        map (\(ix, (_, ty)) -> (ix, offsetExpr ix ty)) $
        zip [1 ..] fields
      where
        offsetExpr :: Int -> Type -> ExpQ
        offsetExpr ix ty = [| $(sizeOfExpr ty) + $(varE (offset (ix - 1))) |]

    -- sizeOf at type ty, with an argument that errors if ever forced.
    sizeOfExpr :: Type -> ExpQ
    sizeOfExpr ty = [| $(varE 'sizeOf) (error "sizeOf evaluated its argument" :: $(return ty)) |]

    offset :: Int -> Name
    offset ix = mkName ("offset" ++ show (ix :: Int))