7c4657262bde486744ae5db9bd21c9586a59710e
[barrelfish] / tools / sockeye / SockeyeNetBuilder.hs
1 {-
2     SockeyeNetBuilder.hs: Decoding net builder for Sockeye
3
4     Part of Sockeye
5
6     Copyright (c) 2017, ETH Zurich.
7
8     All rights reserved.
9
10     This file is distributed under the terms in the attached LICENSE file.
11     If you do not find this file, copies can be found by writing to:
12     ETH Zurich D-INFK, CAB F.78, Universitaetstr. 6, CH-8092 Zurich,
13     Attn: Systems Group.
14 -}
15
16 {-# LANGUAGE MultiParamTypeClasses #-}
17 {-# LANGUAGE FlexibleInstances #-}
18 {-# LANGUAGE FlexibleContexts #-}
19
20 module SockeyeNetBuilder
21 ( buildSockeyeNet ) where
22
23 import Control.Monad.State
24
25 import Data.Either
26 import Data.List (nub, intercalate, sort)
27 import Data.Map (Map)
28 import qualified Data.Map as Map
29 import Data.Maybe (fromMaybe)
30 import Data.Set (Set)
31 import qualified Data.Set as Set
32
33 import SockeyeChecks
34
35 import qualified SockeyeASTInstantiator as InstAST
36 import qualified SockeyeASTDecodingNet as NetAST
37
38 import Debug.Trace
39
40 data NetBuildFail
41     = UndefinedOutPort   !String !String
42     | UndefinedInPort    !String !String
43     | UndefinedReference !String !String
44
45 instance Show NetBuildFail where
46     show (UndefinedInPort  inst port)  = concat ["Mapping to undefined input port '",   port, "' in module instantiation '", inst, "'"]
47     show (UndefinedOutPort inst port)  = concat ["Mapping to undefined output port '",  port, "' in module instantiation '", inst, "'"]
48     show (UndefinedReference context ident) = concat ["Reference to undefined node '", ident, "' in ", context]
49
50 type PortMap = Map InstAST.Identifier NetAST.NodeId
51
52 data Context = Context
53     { modules      :: Map InstAST.Identifier InstAST.Module
54     , curModule    :: !String
55     , curNamespace :: [String]
56     , curNode      :: !String
57     , inPortMap    :: PortMap
58     , outPortMap   :: PortMap
59     , nodes        :: Set String
60     , mappedBlocks :: [InstAST.BlockSpec]
61     }
62
63 buildSockeyeNet :: InstAST.SockeyeSpec -> Either (FailedChecks NetBuildFail) NetAST.NetSpec
64 buildSockeyeNet ast = do
65     let
66         context = Context
67             { modules      = Map.empty
68             , curModule    = ""
69             , curNamespace = []
70             , curNode      = ""
71             , inPortMap    = Map.empty
72             , outPortMap   = Map.empty
73             , nodes        = Set.empty
74             , mappedBlocks = []
75             }        
76     net <- runChecks $ transform context ast
77     return net
78
79 --            
80 -- Build net
81 --
82 class NetTransformable a b where
83     transform :: Context -> a -> Checks NetBuildFail b
84
85 instance NetTransformable InstAST.SockeyeSpec NetAST.NetSpec where
86     transform context ast = do
87         let
88             rootInst = InstAST.root ast
89             mods = InstAST.modules ast
90             specContext = context
91                 { modules = mods }
92         transform specContext rootInst
93
94 instance NetTransformable InstAST.Module NetAST.NetSpec where
95     transform context ast = do
96         let inPorts = InstAST.inputPorts ast
97             outPorts = InstAST.outputPorts ast
98             moduleInsts = InstAST.moduleInsts ast
99             nodeDecls = InstAST.nodeDecls ast
100             outPortIds = map InstAST.portId outPorts
101             inMapIds = concatMap Map.elems $ map InstAST.inPortMap moduleInsts
102             declIds = map InstAST.nodeId nodeDecls
103             modContext = context
104                 { nodes = Set.fromList $ outPortIds ++ inMapIds ++ declIds }
105         -- TODO: check mapping to undefined port
106         inPortDecls <- transform modContext inPorts
107         outPortDecls <- transform modContext outPorts
108         netDecls <- transform modContext nodeDecls
109         netInsts <- transform modContext moduleInsts     
110         return $ Map.unions (inPortDecls ++ outPortDecls ++ netDecls ++ netInsts)
111
112 instance NetTransformable InstAST.Port NetAST.NetSpec where
113     transform context ast@(InstAST.InputPort {}) = do
114         let portId = InstAST.portId ast
115             portWidth = InstAST.portWidth ast
116             portMap = inPortMap context
117             mappedId = Map.lookup portId portMap
118         netPortId <- transform context portId
119         case mappedId of
120             Nothing    -> return Map.empty
121             Just ident -> do
122                 let node = portNode netPortId portWidth
123                 return $ Map.fromList [(ident, node)]
124     transform context ast@(InstAST.OutputPort {}) = do
125         let portId = InstAST.portId ast
126             portWidth = InstAST.portWidth ast
127             portMap = outPortMap context
128             mappedId = Map.lookup portId portMap
129         netPortId <- transform context portId
130         case mappedId of
131             Nothing    -> return $ Map.fromList [(netPortId, portNodeTemplate)]
132             Just ident -> do
133                 let node = portNode ident portWidth
134                 return $ Map.fromList $ [(netPortId, node)]
135
136 portNode :: NetAST.NodeId -> Integer -> NetAST.NodeSpec
137 portNode destNode width =
138     let base = 0
139         limit = 2^width - 1
140         srcBlock = NetAST.BlockSpec
141             { NetAST.base  = base
142             , NetAST.limit = limit
143             }
144         map = NetAST.MapSpec
145                 { NetAST.srcBlock = srcBlock
146                 , NetAST.destNode = destNode
147                 , NetAST.destBase = base
148                 }
149     in portNodeTemplate { NetAST.translate = [map] }
150
151 portNodeTemplate :: NetAST.NodeSpec
152 portNodeTemplate = NetAST.NodeSpec
153     { NetAST.nodeType  = NetAST.Other
154     , NetAST.accept    = []
155     , NetAST.translate = []
156     }
157
158 instance NetTransformable InstAST.ModuleInst NetAST.NetSpec where
159     transform context ast = do
160         let name = InstAST.moduleName ast
161             namespace = InstAST.namespace ast
162             inPortMap = InstAST.inPortMap ast
163             outPortMap = InstAST.outPortMap ast
164             mod = (modules context) Map.! name
165             errorContext = concat ["port mapping for '", name, " as ", namespace, "'"]
166         mapM_ (checkReference context $ UndefinedReference errorContext) $ (Map.elems inPortMap) ++ (Map.elems outPortMap)
167         netInMap <- transform context inPortMap
168         netOutMap <- transform context outPortMap
169         let instContext = context
170                 { curModule    = name
171                 , curNamespace = namespace:(curNamespace context)
172                 , inPortMap    = netInMap
173                 , outPortMap   = netOutMap
174                 }
175         transform instContext mod
176
177 instance NetTransformable InstAST.NodeDecl NetAST.NetSpec where
178     transform context ast = do
179         let nodeId = InstAST.nodeId ast
180             nodeSpec = InstAST.nodeSpec ast
181             nodeContext = context
182                 { curNode = nodeId }
183         netNodeId <- transform context nodeId
184         netNodeSpec <- transform nodeContext nodeSpec
185         return $ Map.fromList [(netNodeId, netNodeSpec)]
186
187 instance NetTransformable InstAST.Identifier NetAST.NodeId where
188     transform context ast = do
189         let namespace = curNamespace context
190         return NetAST.NodeId
191             { NetAST.namespace = namespace
192             , NetAST.name      = ast
193             }
194
195 instance NetTransformable InstAST.NodeSpec NetAST.NodeSpec where
196     transform context ast = do
197         let
198             nodeType = InstAST.nodeType ast
199             accept = InstAST.accept ast
200             translate = InstAST.translate ast
201             reserved = InstAST.reserved ast
202             overlay = InstAST.overlay ast
203         netTranslate <- transform context translate
204         let
205             mapBlocks = map NetAST.srcBlock netTranslate
206             nodeContext = context
207                 { mappedBlocks = accept ++ mapBlocks ++ reserved }
208         netOverlay <- case overlay of
209                 Nothing -> return []
210                 Just o  -> transform nodeContext o
211         return NetAST.NodeSpec
212             { NetAST.nodeType  = nodeType
213             , NetAST.accept    = accept
214             , NetAST.translate = netTranslate ++ netOverlay
215             }
216
217 instance NetTransformable InstAST.MapSpec NetAST.MapSpec where
218     transform context ast = do
219         let
220             srcBlock = InstAST.srcBlock ast
221             destNode = InstAST.destNode ast
222             destBase = InstAST.destBase ast
223             errorContext = "tranlate set of node '" ++ curNode context ++ "'"
224         checkReference context (UndefinedReference errorContext) destNode
225         netDestNode <- transform context destNode
226         return NetAST.MapSpec
227             { NetAST.srcBlock = srcBlock
228             , NetAST.destNode = netDestNode
229             , NetAST.destBase = destBase
230             }
231
232 instance NetTransformable InstAST.OverlaySpec [NetAST.MapSpec] where
233     transform context ast = do
234         let
235             over = InstAST.over ast
236             width = InstAST.width ast
237             blocks = mappedBlocks context
238             errorContext = "overlay of node '" ++ curNode context ++ "'"
239         checkReference context (UndefinedReference errorContext) over
240         netOver <- transform context over
241         let maps = overlayMaps netOver width blocks
242         return maps
243
244 overlayMaps :: NetAST.NodeId -> Integer -> [NetAST.BlockSpec] -> [NetAST.MapSpec]
245 overlayMaps destId width blocks =
246     let
247         blockPoints = concat $ map toScanPoints blocks
248         maxAddress = 2^width
249         overStop  = BlockStart $ maxAddress
250         scanPoints = filter ((maxAddress >=) . address) $ sort (overStop:blockPoints)
251         startState = ScanLineState
252             { insideBlocks    = 0
253             , startAddress    = 0
254             }
255     in evalState (scanLine scanPoints []) startState
256     where
257         toScanPoints (NetAST.BlockSpec base limit) =
258                 [ BlockStart base
259                 , BlockEnd   limit
260                 ]
261         scanLine [] ms = return ms
262         scanLine (p:ps) ms = do
263             maps <- pointAction p ms
264             scanLine ps maps
265         pointAction (BlockStart a) ms = do
266             s <- get       
267             let
268                 i = insideBlocks s
269                 base = startAddress s
270                 limit = a - 1
271             maps <- if (i == 0) && (base <= limit)
272                 then
273                     let
274                         baseAddress = startAddress s
275                         limitAddress = a - 1
276                         srcBlock = NetAST.BlockSpec baseAddress limitAddress
277                         m = NetAST.MapSpec srcBlock destId baseAddress
278                     in return $ m:ms
279                 else return ms
280             modify (\s -> s { insideBlocks = i + 1})
281             return maps
282         pointAction (BlockEnd a) ms = do
283             s <- get
284             let
285                 i = insideBlocks s
286             put $ ScanLineState (i - 1) (a + 1)
287             return ms
288
289 data StoppingPoint
290     = BlockStart { address :: !NetAST.Address }
291     | BlockEnd   { address :: !NetAST.Address }
292     deriving (Eq, Show)
293
294 instance Ord StoppingPoint where
295     (<=) (BlockStart a1) (BlockEnd   a2)
296         | a1 == a2 = True
297         | otherwise = a1 <= a2
298     (<=) (BlockEnd   a1) (BlockStart a2)
299         | a1 == a2 = False
300         | otherwise = a1 <= a2
301     (<=) sp1 sp2 = (address sp1) <= (address sp2)
302
303 data ScanLineState
304     = ScanLineState
305         { insideBlocks :: !Integer
306         , startAddress :: !NetAST.Address
307         } deriving (Show)
308
309 instance (Traversable t, NetTransformable a b) => NetTransformable (t a)  (t b) where
310     transform context as = mapM (transform context) as
311
312 checkReference :: Context -> (String -> NetBuildFail) -> String -> (Checks NetBuildFail) ()
313 checkReference context fail name =
314     if name `Set.member` (nodes context)
315         then return ()
316         else failCheck (curModule context) (fail name)