64acf36063f8d7e80da46a30d2e1b5b1aaa3cbfd
[barrelfish] / tools / sockeye / SockeyeNetBuilder.hs
1 {-
2     SockeyeNetBuilder.hs: Decoding net builder for Sockeye
3
4     Part of Sockeye
5
6     Copyright (c) 2017, ETH Zurich.
7
8     All rights reserved.
9
10     This file is distributed under the terms in the attached LICENSE file.
11     If you do not find this file, copies can be found by writing to:
12     ETH Zurich D-INFK, CAB F.78, Universitaetstr. 6, CH-8092 Zurich,
13     Attn: Systems Group.
14 -}
15
16 {-# LANGUAGE MultiParamTypeClasses #-}
17 {-# LANGUAGE FlexibleInstances #-}
18 {-# LANGUAGE FlexibleContexts #-}
19
20 module SockeyeNetBuilder
21 ( buildSockeyeNet ) where
22
23 import Control.Monad.State
24
25 import Data.Either
26 import Data.List (nub, intercalate, sort)
27 import Data.Map (Map)
28 import qualified Data.Map as Map
29 import Data.Maybe (fromMaybe)
30 import Data.Set (Set)
31 import qualified Data.Set as Set
32
33 import SockeyeChecks
34
35 import qualified SockeyeASTInstantiator as InstAST
36 import qualified SockeyeASTDecodingNet as NetAST
37
38 import Debug.Trace
39
40 data NetBuildFail
41     = UndefinedOutPort !String !String
42     | UndefinedInPort  !String !String
43     | UndefinedRefPort !String
44     | UndefinedRefNode !String !String
45
46 instance Show NetBuildFail where
47     show (UndefinedInPort  inst port)  = concat ["Mapping to undefined input port '",   port, "' in module instantiation '", inst, "'"]
48     show (UndefinedOutPort inst port)  = concat ["Mapping to undefined output port '",  port, "' in module instantiation '", inst, "'"]
49     show (UndefinedRefPort      port)  = concat ["Input port '", port, "' declared but corresponding node not defined"]
50     show (UndefinedRefNode context ident) = concat ["Reference to undefined node '", ident, "' in ", context]
51
52 type PortMap = Map InstAST.Identifier NetAST.NodeId
53
54 data Context = Context
55     { modules      :: Map InstAST.Identifier InstAST.Module
56     , curModule    :: !String
57     , curNamespace :: [String]
58     , curNode      :: !String
59     , inPortMap    :: PortMap
60     , outPortMap   :: PortMap
61     , nodes        :: Set String
62     , mappedBlocks :: [InstAST.BlockSpec]
63     }
64
65 buildSockeyeNet :: InstAST.SockeyeSpec -> Either (FailedChecks NetBuildFail) NetAST.NetSpec
66 buildSockeyeNet ast = do
67     let
68         context = Context
69             { modules      = Map.empty
70             , curModule    = ""
71             , curNamespace = []
72             , curNode      = ""
73             , inPortMap    = Map.empty
74             , outPortMap   = Map.empty
75             , nodes        = Set.empty
76             , mappedBlocks = []
77             }        
78     net <- runChecks $ transform context ast
79     return net
80
81 --            
82 -- Build net
83 --
84 class NetTransformable a b where
85     transform :: Context -> a -> Checks NetBuildFail b
86
87 instance NetTransformable InstAST.SockeyeSpec NetAST.NetSpec where
88     transform context ast = do
89         let
90             rootInst = InstAST.root ast
91             mods = InstAST.modules ast
92             specContext = context
93                 { modules = mods }
94         transform specContext rootInst
95
96 instance NetTransformable InstAST.Module NetAST.NetSpec where
97     transform context ast = do
98         let inPorts = InstAST.inputPorts ast
99             outPorts = InstAST.outputPorts ast
100             moduleInsts = InstAST.moduleInsts ast
101             nodeDecls = InstAST.nodeDecls ast
102             outPortIds = map InstAST.portId outPorts
103             inMapIds = concatMap Map.elems $ map InstAST.inPortMap moduleInsts
104             declIds = map InstAST.nodeId nodeDecls
105             modContext = context
106                 { nodes = Set.fromList $ outPortIds ++ inMapIds ++ declIds }
107         inPortDecls <- transform modContext inPorts
108         outPortDecls <- transform modContext outPorts
109         netDecls <- transform modContext nodeDecls
110         netInsts <- transform modContext moduleInsts     
111         return $ Map.unions (inPortDecls ++ outPortDecls ++ netDecls ++ netInsts)
112
113 instance NetTransformable InstAST.Port NetAST.NetSpec where
114     transform context ast@(InstAST.InputPort {}) = do
115         let portId = InstAST.portId ast
116             portWidth = InstAST.portWidth ast
117             portMap = inPortMap context
118             mappedId = Map.lookup portId portMap
119         netPortId <- transform context portId
120         case mappedId of
121             Nothing    -> return Map.empty
122             Just ident -> do
123                 let node = portNode netPortId portWidth
124                 return $ Map.fromList [(ident, node)]
125     transform context ast@(InstAST.OutputPort {}) = do
126         let portId = InstAST.portId ast
127             portWidth = InstAST.portWidth ast
128             portMap = outPortMap context
129             mappedId = Map.lookup portId portMap
130         netPortId <- transform context portId
131         case mappedId of
132             Nothing    -> return $ Map.fromList [(netPortId, portNodeTemplate)]
133             Just ident -> do
134                 let node = portNode ident portWidth
135                 return $ Map.fromList $ [(netPortId, node)]
136
137 portNode :: NetAST.NodeId -> Integer -> NetAST.NodeSpec
138 portNode destNode width =
139     let base = 0
140         limit = 2^width - 1
141         srcBlock = NetAST.BlockSpec
142             { NetAST.base  = base
143             , NetAST.limit = limit
144             }
145         map = NetAST.MapSpec
146                 { NetAST.srcBlock = srcBlock
147                 , NetAST.destNode = destNode
148                 , NetAST.destBase = base
149                 }
150     in portNodeTemplate { NetAST.translate = [map] }
151
152 portNodeTemplate :: NetAST.NodeSpec
153 portNodeTemplate = NetAST.NodeSpec
154     { NetAST.nodeType  = NetAST.Other
155     , NetAST.accept    = []
156     , NetAST.translate = []
157     }
158
159 instance NetTransformable InstAST.ModuleInst NetAST.NetSpec where
160     transform context ast = do
161         let name = InstAST.moduleName ast
162             namespace = InstAST.namespace ast
163             inPortMap = InstAST.inPortMap ast
164             outPortMap = InstAST.outPortMap ast
165             mod = (modules context) Map.! name
166         netInMap <- transform context inPortMap
167         netOutMap <- transform context outPortMap
168         let instContext = context
169                 { curModule    = name
170                 , curNamespace = namespace:(curNamespace context)
171                 , inPortMap    = netInMap
172                 , outPortMap   = netOutMap
173                 }
174         transform instContext mod
175
176 instance NetTransformable InstAST.NodeDecl NetAST.NetSpec where
177     transform context ast = do
178         let nodeId = InstAST.nodeId ast
179             nodeSpec = InstAST.nodeSpec ast
180             nodeContext = context
181                 { curNode = nodeId }
182         netNodeId <- transform context nodeId
183         netNodeSpec <- transform nodeContext nodeSpec
184         return $ Map.fromList [(netNodeId, netNodeSpec)]
185
186 instance NetTransformable InstAST.Identifier NetAST.NodeId where
187     transform context ast = do
188         let namespace = curNamespace context
189         return NetAST.NodeId
190             { NetAST.namespace = namespace
191             , NetAST.name      = ast
192             }
193
194 instance NetTransformable InstAST.NodeSpec NetAST.NodeSpec where
195     transform context ast = do
196         let
197             nodeType = InstAST.nodeType ast
198             accept = InstAST.accept ast
199             translate = InstAST.translate ast
200             reserved = InstAST.reserved ast
201             overlay = InstAST.overlay ast
202         netTranslate <- transform context translate
203         let
204             mapBlocks = map NetAST.srcBlock netTranslate
205             nodeContext = context
206                 { mappedBlocks = accept ++ mapBlocks ++ reserved }
207         netOverlay <- case overlay of
208                 Nothing -> return []
209                 Just o  -> transform nodeContext o
210         return NetAST.NodeSpec
211             { NetAST.nodeType  = nodeType
212             , NetAST.accept    = accept
213             , NetAST.translate = netTranslate ++ netOverlay
214             }
215
216 instance NetTransformable InstAST.MapSpec NetAST.MapSpec where
217     transform context ast = do
218         let
219             srcBlock = InstAST.srcBlock ast
220             destNode = InstAST.destNode ast
221             destBase = InstAST.destBase ast
222             errorContext = "tranlate set of node '" ++ curNode context ++ "'"
223         checkReference context (UndefinedRefNode errorContext) destNode
224         netDestNode <- transform context destNode
225         return NetAST.MapSpec
226             { NetAST.srcBlock = srcBlock
227             , NetAST.destNode = netDestNode
228             , NetAST.destBase = destBase
229             }
230
231 instance NetTransformable InstAST.OverlaySpec [NetAST.MapSpec] where
232     transform context ast = do
233         let
234             over = InstAST.over ast
235             width = InstAST.width ast
236             blocks = mappedBlocks context
237             errorContext = "overlay of node '" ++ curNode context ++ "'"
238         checkReference context (UndefinedRefNode errorContext) over
239         netOver <- transform context over
240         let maps = overlayMaps netOver width blocks
241         return maps
242
243 overlayMaps :: NetAST.NodeId -> Integer -> [NetAST.BlockSpec] -> [NetAST.MapSpec]
244 overlayMaps destId width blocks =
245     let
246         blockPoints = concat $ map toScanPoints blocks
247         maxAddress = 2^width
248         overStop  = BlockStart $ maxAddress
249         scanPoints = filter ((maxAddress >=) . address) $ sort (overStop:blockPoints)
250         startState = ScanLineState
251             { insideBlocks    = 0
252             , startAddress    = 0
253             }
254     in evalState (scanLine scanPoints []) startState
255     where
256         toScanPoints (NetAST.BlockSpec base limit) =
257                 [ BlockStart base
258                 , BlockEnd   limit
259                 ]
260         scanLine [] ms = return ms
261         scanLine (p:ps) ms = do
262             maps <- pointAction p ms
263             scanLine ps maps
264         pointAction (BlockStart a) ms = do
265             s <- get       
266             let
267                 i = insideBlocks s
268                 base = startAddress s
269                 limit = a - 1
270             maps <- if (i == 0) && (base <= limit)
271                 then
272                     let
273                         baseAddress = startAddress s
274                         limitAddress = a - 1
275                         srcBlock = NetAST.BlockSpec baseAddress limitAddress
276                         m = NetAST.MapSpec srcBlock destId baseAddress
277                     in return $ m:ms
278                 else return ms
279             modify (\s -> s { insideBlocks = i + 1})
280             return maps
281         pointAction (BlockEnd a) ms = do
282             s <- get
283             let
284                 i = insideBlocks s
285             put $ ScanLineState (i - 1) (a + 1)
286             return ms
287
288 data StoppingPoint
289     = BlockStart { address :: !NetAST.Address }
290     | BlockEnd   { address :: !NetAST.Address }
291     deriving (Eq, Show)
292
293 instance Ord StoppingPoint where
294     (<=) (BlockStart a1) (BlockEnd   a2)
295         | a1 == a2 = True
296         | otherwise = a1 <= a2
297     (<=) (BlockEnd   a1) (BlockStart a2)
298         | a1 == a2 = False
299         | otherwise = a1 <= a2
300     (<=) sp1 sp2 = (address sp1) <= (address sp2)
301
302 data ScanLineState
303     = ScanLineState
304         { insideBlocks :: !Integer
305         , startAddress :: !NetAST.Address
306         } deriving (Show)
307
308 instance (Traversable t, NetTransformable a b) => NetTransformable (t a)  (t b) where
309     transform context as = mapM (transform context) as
310
311 checkReference :: Context -> (String -> NetBuildFail) -> String -> (Checks NetBuildFail) ()
312 checkReference context fail name =
313     if name `Set.member` (nodes context)
314         then return ()
315         else failCheck (curModule context) (fail name)