Sockeye: Add possibility to add reserved blocks
[barrelfish] / tools / sockeye / SockeyeNetBuilder.hs
1 {-
2     SockeyeNetBuilder.hs: Decoding net builder for Sockeye
3
4     Part of Sockeye
5
6     Copyright (c) 2017, ETH Zurich.
7
8     All rights reserved.
9
10     This file is distributed under the terms in the attached LICENSE file.
11     If you do not find this file, copies can be found by writing to:
12     ETH Zurich D-INFK, CAB F.78, Universitaetstr. 6, CH-8092 Zurich,
13     Attn: Systems Group.
14 -}
15
16 {-# LANGUAGE MultiParamTypeClasses #-}
17 {-# LANGUAGE FlexibleInstances #-}
18 {-# LANGUAGE FlexibleContexts #-}
19
20 module SockeyeNetBuilder
21 ( sockeyeBuildNet ) where
22
23 import Control.Monad.State
24
25 import Data.Either
26 import Data.List (nub, intercalate, sort)
27 import Data.Map (Map)
28 import qualified Data.Map as Map
29 import Data.Maybe (fromMaybe)
30 import Data.Set (Set)
31 import qualified Data.Set as Set
32
33 import qualified SockeyeAST as AST
34 import qualified SockeyeASTDecodingNet as NetAST
35
36 type NetNodeDecl = (NetAST.NodeId, NetAST.NodeSpec)
37 type NetList = [NetNodeDecl]
38 type PortMap = [(String, NetAST.NodeId)]
39
40 data FailedCheck
41     = ModuleInstLoop [String]
42     | DuplicateInPort !String !String
43     | DuplicateInMap !String !String
44     | UndefinedInPort !String !String
45     | DuplicateOutPort !String !String
46     | DuplicateOutMap !String !String
47     | UndefinedOutPort !String !String
48     | DuplicateIdentifer !String
49     | UndefinedReference !String
50
51 instance Show FailedCheck where
52     show (ModuleInstLoop loop) = concat ["Module instantiation loop:'", intercalate "' -> '" loop, "'"]
53     show (DuplicateInPort  modName port) = concat ["Multiple declarations of input port '", port, "' in '", modName, "'"]
54     show (DuplicateInMap   ns      port) = concat ["Multiple mappings for input port '", port, "' in '", ns, "'"]
55     show (UndefinedInPort  modName port) = concat ["'", port, "' is not an input port in '", modName, "'"]
56     show (DuplicateOutPort modName port) = concat ["Multiple declarations of output port '", port, "' in '", modName, "'"]
57     show (DuplicateOutMap   ns      port) = concat ["Multiple mappings for output port '", port, "' in '", ns, "'"]
58     show (UndefinedOutPort modName port) = concat ["'", port, "' is not an output port in '", modName, "'"]
59     show (DuplicateIdentifer ident)   = concat ["Multiple declarations of node '", show ident, "'"]
60     show (UndefinedReference ident)   = concat ["Reference to undefined node '", show ident, "'"]
61
62 newtype CheckFailure = CheckFailure
63     { failures :: [FailedCheck] }
64
65 instance Show CheckFailure where
66     show (CheckFailure fs) = unlines $ "":(map show fs)
67
68 data Context = Context
69     { spec         :: AST.SockeyeSpec
70     , modulePath   :: [String]
71     , curNamespace :: NetAST.Namespace
72     , paramValues  :: Map String Integer
73     , varValues    :: Map String Integer
74     , inPortMaps   :: Map String NetAST.NodeId
75     , outPortMaps  :: Map String NetAST.NodeId
76     , mappedBlocks :: [NetAST.BlockSpec]
77     }
78
79 sockeyeBuildNet :: AST.SockeyeSpec -> Either CheckFailure NetAST.NetSpec
80 sockeyeBuildNet ast = do
81     let
82         context = Context
83             { spec         = AST.SockeyeSpec Map.empty
84             , modulePath   = []
85             , curNamespace = NetAST.Namespace []
86             , paramValues  = Map.empty
87             , varValues    = Map.empty
88             , inPortMaps   = Map.empty
89             , outPortMaps  = Map.empty
90             , mappedBlocks = []
91             }        
92     net <- transform context ast
93     check Set.empty net
94     return net
95 --            
96 -- Build net
97 --
98 class NetTransformable a b where
99     transform :: Context -> a -> Either CheckFailure b
100
101 instance NetTransformable AST.SockeyeSpec NetAST.NetSpec where
102     transform context ast = do
103         let
104             rootInst = AST.ModuleInst
105                 { AST.namespace  = AST.SimpleIdent ""
106                 , AST.moduleName = "@root"
107                 , AST.arguments  = Map.empty
108                 , AST.inPortMap  = []
109                 , AST.outPortMap = []
110                 }
111             specContext = context
112                 { spec = ast }
113         netList <- transform specContext rootInst
114         let
115             nodeIds = map fst netList
116         checkDuplicates nodeIds DuplicateIdentifer
117         let
118             nodeMap = Map.fromList netList
119         return $ NetAST.NetSpec nodeMap
120
121 instance NetTransformable AST.Module NetList where
122     transform context ast = do
123         let
124             inPorts = AST.inputPorts ast
125             outPorts = AST.outputPorts ast
126             nodeDecls = AST.nodeDecls ast
127             moduleInsts = AST.moduleInsts ast
128         inDecls <- do
129             net <- transform context inPorts
130             return $ concat (net :: [NetList])
131         outDecls <- do
132             net <- transform context outPorts
133             return $ concat (net :: [NetList])
134         -- TODO check duplicate ports
135         -- TODO check mappings to non existing port
136         netDecls <- transform context nodeDecls
137         netInsts <- transform context moduleInsts
138         return $ concat (inDecls:outDecls:netDecls ++ netInsts)            
139
140 instance NetTransformable AST.Port NetList where
141     transform context (AST.MultiPort for) = do
142         netPorts <- transform context for
143         return $ concat (netPorts :: [NetList])
144     transform context (AST.InputPort portId portWidth) = do
145         netPortId <- transform context portId
146         let
147             portMap = inPortMaps context
148             name = NetAST.name netPortId
149             mappedId = Map.lookup name portMap
150         case mappedId of
151             Nothing    -> return []
152             Just ident -> do
153                 let
154                     node = portNode netPortId portWidth
155                 return [(ident, node)]
156     transform context (AST.OutputPort portId portWidth) = do
157         netPortId <- transform context portId
158         let
159             portMap = outPortMaps context
160             name = NetAST.name netPortId
161             mappedId = Map.lookup name portMap
162         case mappedId of
163             Nothing    -> return [(netPortId, portNodeTemplate)]
164             Just ident -> do
165                 let
166                     node = portNode ident portWidth
167                 return [(netPortId, node)]
168
169 portNode :: NetAST.NodeId -> Integer -> NetAST.NodeSpec
170 portNode destNode width =
171     let
172         base = NetAST.Address 0
173         limit = NetAST.Address $ 2^width - 1
174         srcBlock = NetAST.BlockSpec
175             { NetAST.base  = base
176             , NetAST.limit = limit
177             }
178         map = NetAST.MapSpec
179                 { NetAST.srcBlock = srcBlock
180                 , NetAST.destNode = destNode
181                 , NetAST.destBase = base
182                 }
183     in portNodeTemplate { NetAST.translate = [map] }
184
185 portNodeTemplate :: NetAST.NodeSpec
186 portNodeTemplate = NetAST.NodeSpec
187     { NetAST.nodeType  = NetAST.Other
188     , NetAST.accept    = []
189     , NetAST.translate = []
190     }    
191
192 instance NetTransformable AST.ModuleInst NetList where
193     transform context (AST.MultiModuleInst for) = do
194         net <- transform context for
195         return $ concat (net :: [NetList])
196     transform context ast = do
197         let
198             namespace = AST.namespace ast
199             name = AST.moduleName ast
200             args = AST.arguments ast
201             inPortMap = AST.inPortMap ast
202             outPortMap = AST.outPortMap ast
203             mod = getModule context name
204         checkSelfInst name
205         netNamespace <- transform context namespace
206         netArgs <- transform context args
207         netInMap <- transform context inPortMap
208         netOutMap <- transform context outPortMap
209         let
210             inMaps = concat (netInMap :: [PortMap])
211             outMaps = concat (netOutMap :: [PortMap])
212         checkDuplicates (map fst inMaps) (DuplicateInMap $ show netNamespace) 
213         checkDuplicates (map fst outMaps) (DuplicateOutMap $ show netNamespace)
214         let
215             modContext = moduleContext name netNamespace netArgs inMaps outMaps
216         transform modContext mod
217             where
218                 moduleContext name namespace args inMaps outMaps =
219                     let
220                         path = modulePath context
221                         base = NetAST.ns $ NetAST.namespace namespace
222                         newNs = case NetAST.name namespace of
223                             "" -> NetAST.Namespace base
224                             n  -> NetAST.Namespace $ n:base
225                     in context
226                         { modulePath   = name:path
227                         , curNamespace = newNs
228                         , paramValues  = args
229                         , varValues    = Map.empty
230                         , inPortMaps   = Map.fromList inMaps
231                         , outPortMaps  = Map.fromList outMaps
232                         }
233                 checkSelfInst name = do
234                     let
235                         path = modulePath context
236                     case loop path of
237                         [] -> return ()
238                         l  -> Left $ CheckFailure [ModuleInstLoop (reverse $ name:l)]
239                         where
240                             loop [] = []
241                             loop path@(p:ps)
242                                 | name `elem` path = p:(loop ps)
243                                 | otherwise = []
244
245
246 instance NetTransformable AST.PortMap PortMap where
247     transform context (AST.MultiPortMap for) = do
248         ts <- transform context for
249         return $ concat (ts :: [PortMap])
250     transform context ast = do
251         let
252             mappedId = AST.mappedId ast
253             mappedPort = AST.mappedPort ast
254         netMappedId <- transform context mappedId
255         netMappedPort <- transform context mappedPort
256         return [(NetAST.name netMappedPort, netMappedId)]
257
258 instance NetTransformable AST.ModuleArg Integer where
259     transform _ (AST.AddressArg value) = return value
260     transform _ (AST.NaturalArg value) = return value
261     transform context (AST.ParamArg name) = return $ getParamValue context name
262
263 instance NetTransformable AST.Identifier NetAST.NodeId where
264     transform context ast = do
265         let
266             namespace = curNamespace context
267             name = identName ast
268         return NetAST.NodeId
269             { NetAST.namespace = namespace
270             , NetAST.name      = name
271             }
272             where
273                 identName (AST.SimpleIdent name) = name
274                 identName ident =
275                     let
276                         prefix = AST.prefix ident
277                         varName = AST.varName ident
278                         suffix = AST.suffix ident
279                         varValue = show $ getVarValue context varName
280                         suffixName = case suffix of
281                             Nothing -> ""
282                             Just s  -> identName s
283                     in prefix ++ varValue ++ suffixName
284
285 instance NetTransformable AST.NodeDecl NetList where
286     transform context (AST.MultiNodeDecl for) = do
287         ts <- transform context for
288         return $ concat (ts :: [NetList])
289     transform context ast = do
290         let
291             ident = AST.nodeId ast
292             nodeSpec = AST.nodeSpec ast
293         nodeId <- transform context ident
294         netNodeSpec <- transform context nodeSpec
295         return [(nodeId, netNodeSpec)]
296
297 instance NetTransformable AST.NodeSpec NetAST.NodeSpec where
298     transform context ast = do
299         let
300             nodeType = AST.nodeType ast
301             accept = AST.accept ast
302             translate = AST.translate ast
303             reserved = AST.reserved ast
304             overlay = AST.overlay ast
305         netNodeType <- maybe (return NetAST.Other) (transform context) nodeType
306         netAccept <- transform context accept
307         netTranslate <- transform context translate
308         netReserved <- transform context reserved
309         let
310             mapBlocks = map NetAST.srcBlock netTranslate
311             nodeContext = context
312                 { mappedBlocks = netAccept ++ mapBlocks ++ netReserved }
313         netOverlay <- case overlay of
314                 Nothing -> return []
315                 Just o  -> transform nodeContext o
316         return NetAST.NodeSpec
317             { NetAST.nodeType  = netNodeType
318             , NetAST.accept    = netAccept
319             , NetAST.translate = netTranslate ++ netOverlay
320             }
321
322 instance NetTransformable AST.NodeType NetAST.NodeType where
323     transform _ AST.Memory = return NetAST.Memory
324     transform _ AST.Device = return NetAST.Device
325
326 instance NetTransformable AST.BlockSpec NetAST.BlockSpec where
327     transform context (AST.SingletonBlock address) = do
328         netAddress <- transform context address
329         return NetAST.BlockSpec
330             { NetAST.base  = netAddress
331             , NetAST.limit = netAddress
332             }
333     transform context (AST.RangeBlock base limit) = do
334         netBase <- transform context base
335         netLimit <- transform context limit
336         return NetAST.BlockSpec
337             { NetAST.base  = netBase
338             , NetAST.limit = netLimit
339             }
340     transform context (AST.LengthBlock base bits) = do
341         netBase <- transform context base
342         let
343             baseAddress = NetAST.address netBase
344             limit = baseAddress + 2^bits - 1
345             netLimit = NetAST.Address limit
346         return NetAST.BlockSpec
347             { NetAST.base  = netBase
348             , NetAST.limit = netLimit
349             }
350
351 instance NetTransformable AST.MapSpec NetAST.MapSpec where
352     transform context ast = do
353         let
354             block = AST.block ast
355             destNode = AST.destNode ast
356             destBase = fromMaybe (AST.LiteralAddress 0) (AST.destBase ast)
357         netBlock <- transform context block
358         netDestNode <- transform context destNode
359         netDestBase <- transform context destBase
360         return NetAST.MapSpec
361             { NetAST.srcBlock = netBlock
362             , NetAST.destNode = netDestNode
363             , NetAST.destBase = netDestBase
364             }
365
366 instance NetTransformable AST.OverlaySpec [NetAST.MapSpec] where
367     transform context ast = do
368         let
369             over = AST.over ast
370             width = AST.width ast
371             blocks = mappedBlocks context
372         netOver <- transform context over
373         let
374             maps = overlayMaps netOver width blocks
375         return maps
376
377 overlayMaps :: NetAST.NodeId -> Integer ->[NetAST.BlockSpec] -> [NetAST.MapSpec]
378 overlayMaps destId width blocks =
379     let
380         blockPoints = concat $ map toScanPoints blocks
381         maxAddress = 2^width
382         overStop  = BlockStart $ maxAddress
383         scanPoints = filter ((maxAddress >=) . address) $ sort (overStop:blockPoints)
384         startState = ScanLineState
385             { insideBlocks    = 0
386             , startAddress    = 0
387             }
388     in evalState (scanLine scanPoints []) startState
389     where
390         toScanPoints (NetAST.BlockSpec base limit) =
391                 [ BlockStart $ NetAST.address base
392                 , BlockEnd   $ NetAST.address limit
393                 ]
394         scanLine [] ms = return ms
395         scanLine (p:ps) ms = do
396             maps <- pointAction p ms
397             scanLine ps maps
398         pointAction (BlockStart a) ms = do
399             s <- get       
400             let
401                 i = insideBlocks s
402                 base = startAddress s
403                 limit = a - 1
404             maps <- if (i == 0) && (base <= limit)
405                 then
406                     let
407                         baseAddress = NetAST.Address $ startAddress s
408                         limitAddress = NetAST.Address $ a - 1
409                         srcBlock = NetAST.BlockSpec baseAddress limitAddress
410                         m = NetAST.MapSpec srcBlock destId baseAddress
411                     in return $ m:ms
412                 else return ms
413             modify (\s -> s { insideBlocks = i + 1})
414             return maps
415         pointAction (BlockEnd a) ms = do
416             s <- get
417             let
418                 i = insideBlocks s
419             put $ ScanLineState (i - 1) (a + 1)
420             return ms
421
422 data StoppingPoint
423     = BlockStart { address :: !Integer }
424     | BlockEnd   { address :: !Integer }
425     deriving (Eq, Show)
426
427 instance Ord StoppingPoint where
428     (<=) (BlockStart a1) (BlockEnd   a2)
429         | a1 == a2 = True
430         | otherwise = a1 <= a2
431     (<=) (BlockEnd   a1) (BlockStart a2)
432         | a1 == a2 = False
433         | otherwise = a1 <= a2
434     (<=) sp1 sp2 = (address sp1) <= (address sp2)
435
436 data ScanLineState
437     = ScanLineState
438         { insideBlocks :: !Integer
439         , startAddress :: !Integer
440         } deriving (Show)
441
442 instance NetTransformable AST.Address NetAST.Address where
443     transform _ (AST.LiteralAddress value) = do
444         return $ NetAST.Address value
445     transform context (AST.ParamAddress name) = do
446         let
447             value = getParamValue context name
448         return $ NetAST.Address value
449
450 instance NetTransformable a b => NetTransformable (AST.For a) [b] where
451     transform context ast = do
452         let
453             body = AST.body ast
454             varRanges = AST.varRanges ast
455         concreteRanges <- transform context varRanges
456         let
457             valueList = Map.foldWithKey iterations [] concreteRanges
458             iterContexts = map iterationContext valueList
459             ts = map (\c -> transform c body) iterContexts
460             fs = lefts ts
461             bs = rights ts
462         case fs of
463             [] -> return $ bs
464             _  -> Left $ CheckFailure (concat $ map failures fs)
465         where
466             iterations k vs [] = [Map.fromList [(k,v)] | v <- vs]
467             iterations k vs ms = concat $ map (f ms k) vs
468                 where
469                     f ms k v = map (Map.insert k v) ms
470             iterationContext varMap =
471                 let
472                     values = varValues context
473                 in context
474                     { varValues = values `Map.union` varMap }
475
476 instance NetTransformable AST.ForRange [Integer] where
477     transform context ast = do
478         let
479             start = AST.start ast
480             end = AST.end ast
481         startVal <- transform context start
482         endVal <- transform context end
483         return [startVal..endVal]
484
485 instance NetTransformable AST.ForLimit Integer where
486     transform _ (AST.LiteralLimit value) = return value
487     transform context (AST.ParamLimit name) = return $ getParamValue context name
488
489 instance NetTransformable a b => NetTransformable [a] [b] where
490     transform context ast = do
491         let
492             ts = map (transform context) ast
493             fs = lefts ts
494             bs = rights ts
495         case fs of
496             [] -> return bs
497             _  -> Left $ CheckFailure (concat $ map failures fs)
498
499 instance (Ord k, NetTransformable a b) => NetTransformable (Map k a) (Map k b) where
500     transform context ast = do
501         let
502             ks = Map.keys ast
503             es = Map.elems ast
504         ts <- transform context es
505         return $ Map.fromList (zip ks ts)
506
507 --
508 -- Checks
509 --
510 class NetCheckable a where
511     check :: Set NetAST.NodeId -> a -> Either CheckFailure ()
512
513 instance NetCheckable NetAST.NetSpec where
514     check _ (NetAST.NetSpec net) = do
515         let
516             specContext = Map.keysSet net
517         check specContext $ Map.elems net
518
519 instance NetCheckable NetAST.NodeSpec where
520     check context net = do
521         let
522             translate = NetAST.translate net
523         check context translate
524
525 instance NetCheckable NetAST.MapSpec where
526     check context net = do
527         let
528            destNode = NetAST.destNode net
529         check context destNode
530
531 instance NetCheckable NetAST.NodeId where
532     check context net = do
533         if net `Set.member` context
534             then return ()
535             else Left $ CheckFailure [UndefinedReference $ show net]
536
537 instance NetCheckable a => NetCheckable [a] where
538     check context net = do
539         let
540             checked = map (check context) net
541             fs = lefts $ checked
542         case fs of
543             [] -> return ()
544             _  -> Left $ CheckFailure (concat $ map failures fs)
545
546 getModule :: Context -> String -> AST.Module
547 getModule context name =
548     let
549         modules = AST.modules $ spec context
550     in modules Map.! name
551
552 getParamValue :: Context -> String -> Integer
553 getParamValue context name =
554     let
555         params = paramValues context
556     in params Map.! name
557
558 getVarValue :: Context -> String -> Integer
559 getVarValue context name =
560     let
561         vars = varValues context
562     in vars Map.! name
563
564 checkDuplicates :: (Eq a, Show a) => [a] -> (String -> FailedCheck) -> Either CheckFailure ()
565 checkDuplicates nodeIds fail = do
566     let
567         duplicates = duplicateNames nodeIds
568     case duplicates of
569         [] -> return ()
570         _  -> Left $ CheckFailure (map (fail . show) duplicates)
571     where
572         duplicateNames [] = []
573         duplicateNames (x:xs)
574             | x `elem` xs = nub $ [x] ++ duplicateNames xs
575             | otherwise = duplicateNames xs