fe76bb264d41db55c8507f73b95d4a85e98e6412
[barrelfish] / tools / sockeye / SockeyeNetBuilder.hs
1 {-
2     SockeyeNetBuilder.hs: Decoding net builder for Sockeye
3
4     Part of Sockeye
5
6     Copyright (c) 2017, ETH Zurich.
7
8     All rights reserved.
9
10     This file is distributed under the terms in the attached LICENSE file.
11     If you do not find this file, copies can be found by writing to:
12     ETH Zurich D-INFK, CAB F.78, Universitaetstr. 6, CH-8092 Zurich,
13     Attn: Systems Group.
14 -}
15
16 {-# LANGUAGE MultiParamTypeClasses #-}
17 {-# LANGUAGE FlexibleInstances #-}
18 {-# LANGUAGE FlexibleContexts #-}
19
20 module SockeyeNetBuilder
21 ( sockeyeBuildNet ) where
22
23 import Control.Monad.State
24
25 import Data.Either
26 import Data.List (nub, intercalate, sort)
27 import Data.Map (Map)
28 import qualified Data.Map as Map
29 import Data.Maybe (fromMaybe)
30 import Data.Set (Set)
31 import qualified Data.Set as Set
32
33 import qualified SockeyeAST as AST
34 import qualified SockeyeASTDecodingNet as NetAST
35
36 type NetNodeDecl = (NetAST.NodeId, NetAST.NodeSpec)
37 type NetList = [NetNodeDecl]
38 type PortMap = [(String, NetAST.NodeId)]
39
40 data FailedCheck
41     = ModuleInstLoop [String]
42     | DuplicateInPort !String !String
43     | DuplicateInMap !String !String
44     | UndefinedInPort !String !String
45     | DuplicateOutPort !String !String
46     | DuplicateOutMap !String !String
47     | UndefinedOutPort !String !String
48     | DuplicateIdentifer !String
49     | UndefinedReference !String
50
51 instance Show FailedCheck where
52     show (ModuleInstLoop loop) = concat ["Module instantiation loop:'", intercalate "' -> '" loop, "'"]
53     show (DuplicateInPort  modName port) = concat ["Multiple declarations of input port '", port, "' in '", modName, "'"]
54     show (DuplicateInMap   ns      port) = concat ["Multiple mappings for input port '", port, "' in '", ns, "'"]
55     show (UndefinedInPort  modName port) = concat ["'", port, "' is not an input port in '", modName, "'"]
56     show (DuplicateOutPort modName port) = concat ["Multiple declarations of output port '", port, "' in '", modName, "'"]
57     show (DuplicateOutMap   ns      port) = concat ["Multiple mappings for output port '", port, "' in '", ns, "'"]
58     show (UndefinedOutPort modName port) = concat ["'", port, "' is not an output port in '", modName, "'"]
59     show (DuplicateIdentifer ident)   = concat ["Multiple declarations of node '", show ident, "'"]
60     show (UndefinedReference ident)   = concat ["Reference to undefined node '", show ident, "'"]
61
62 newtype CheckFailure = CheckFailure
63     { failures :: [FailedCheck] }
64
65 instance Show CheckFailure where
66     show (CheckFailure fs) = unlines $ "":(map show fs)
67
68 data Context = Context
69     { spec         :: AST.SockeyeSpec
70     , modulePath   :: [String]
71     , curNamespace :: NetAST.Namespace
72     , paramValues  :: Map String Integer
73     , varValues    :: Map String Integer
74     , inPortMaps   :: Map String NetAST.NodeId
75     , outPortMaps  :: Map String NetAST.NodeId
76     , mappedBlocks :: [NetAST.BlockSpec]
77     }
78
79 sockeyeBuildNet :: AST.SockeyeSpec -> Either CheckFailure NetAST.NetSpec
80 sockeyeBuildNet ast = do
81     let
82         context = Context
83             { spec         = AST.SockeyeSpec Map.empty
84             , modulePath   = []
85             , curNamespace = NetAST.Namespace []
86             , paramValues  = Map.empty
87             , varValues    = Map.empty
88             , inPortMaps   = Map.empty
89             , outPortMaps  = Map.empty
90             , mappedBlocks = []
91             }        
92     net <- transform context ast
93     check Set.empty net
94     return net
95 --            
96 -- Build net
97 --
98 class NetTransformable a b where
99     transform :: Context -> a -> Either CheckFailure b
100
101 instance NetTransformable AST.SockeyeSpec NetAST.NetSpec where
102     transform context ast = do
103         let
104             rootInst = AST.ModuleInst
105                 { AST.namespace  = AST.SimpleIdent ""
106                 , AST.moduleName = "@root"
107                 , AST.arguments  = Map.empty
108                 , AST.inPortMap  = []
109                 , AST.outPortMap = []
110                 }
111             specContext = context
112                 { spec = ast }
113         netList <- transform specContext rootInst
114         let
115             nodeIds = map fst netList
116         checkDuplicates nodeIds DuplicateIdentifer
117         let
118             nodeMap = Map.fromList netList
119         return $ NetAST.NetSpec nodeMap
120
121 instance NetTransformable AST.Module NetList where
122     transform context ast = do
123         let
124             inPorts = AST.inputPorts ast
125             outPorts = AST.outputPorts ast
126             nodeDecls = AST.nodeDecls ast
127             moduleInsts = AST.moduleInsts ast
128         inDecls <- do
129             net <- transform context inPorts
130             return $ concat (net :: [NetList])
131         outDecls <- do
132             net <- transform context outPorts
133             return $ concat (net :: [NetList])
134         -- TODO check duplicate ports
135         -- TODO check mappings to non existing port
136         netDecls <- transform context nodeDecls
137         netInsts <- transform context moduleInsts
138         return $ concat (inDecls:outDecls:netDecls ++ netInsts)            
139
140 instance NetTransformable AST.Port NetList where
141     transform context (AST.MultiPort for) = do
142         netPorts <- transform context for
143         return $ concat (netPorts :: [NetList])
144     transform context (AST.InputPort portId portWidth) = do
145         netPortId <- transform context portId
146         let
147             portMap = inPortMaps context
148             name = NetAST.name netPortId
149             mappedId = Map.lookup name portMap
150         case mappedId of
151             Nothing    -> return []
152             Just ident -> do
153                 let
154                     node = portNode netPortId portWidth
155                 return [(ident, node)]
156     transform context (AST.OutputPort portId portWidth) = do
157         netPortId <- transform context portId
158         let
159             portMap = outPortMaps context
160             name = NetAST.name netPortId
161             mappedId = Map.lookup name portMap
162         case mappedId of
163             Nothing    -> return [(netPortId, portNodeTemplate)]
164             Just ident -> do
165                 let
166                     node = portNode ident portWidth
167                 return [(netPortId, node)]
168
169 portNode :: NetAST.NodeId -> Integer -> NetAST.NodeSpec
170 portNode destNode width =
171     let
172         base = NetAST.Address 0
173         limit = NetAST.Address $ 2^width - 1
174         srcBlock = NetAST.BlockSpec
175             { NetAST.base  = base
176             , NetAST.limit = limit
177             }
178         map = NetAST.MapSpec
179                 { NetAST.srcBlock = srcBlock
180                 , NetAST.destNode = destNode
181                 , NetAST.destBase = base
182                 }
183     in portNodeTemplate { NetAST.translate = [map] }
184
185 portNodeTemplate :: NetAST.NodeSpec
186 portNodeTemplate = NetAST.NodeSpec
187     { NetAST.nodeType  = NetAST.Other
188     , NetAST.accept    = []
189     , NetAST.translate = []
190     }    
191
192 instance NetTransformable AST.ModuleInst NetList where
193     transform context (AST.MultiModuleInst for) = do
194         net <- transform context for
195         return $ concat (net :: [NetList])
196     transform context ast = do
197         let
198             namespace = AST.namespace ast
199             name = AST.moduleName ast
200             args = AST.arguments ast
201             inPortMap = AST.inPortMap ast
202             outPortMap = AST.outPortMap ast
203             mod = getModule context name
204         checkSelfInst name
205         netNamespace <- transform context namespace
206         netArgs <- transform context args
207         netInMap <- transform context inPortMap
208         netOutMap <- transform context outPortMap
209         let
210             inMaps = concat (netInMap :: [PortMap])
211             outMaps = concat (netOutMap :: [PortMap])
212         checkDuplicates (map fst inMaps) (DuplicateInMap $ show netNamespace) 
213         checkDuplicates (map fst outMaps) (DuplicateOutMap $ show netNamespace)
214         let
215             modContext = moduleContext name netNamespace netArgs inMaps outMaps
216         transform modContext mod
217             where
218                 moduleContext name namespace args inMaps outMaps =
219                     let
220                         path = modulePath context
221                         base = NetAST.ns $ NetAST.namespace namespace
222                         newNs = case NetAST.name namespace of
223                             "" -> NetAST.Namespace base
224                             n  -> NetAST.Namespace $ n:base
225                     in context
226                         { modulePath   = name:path
227                         , curNamespace = newNs
228                         , paramValues  = args
229                         , varValues    = Map.empty
230                         , inPortMaps   = Map.fromList inMaps
231                         , outPortMaps  = Map.fromList outMaps
232                         }
233                 checkSelfInst name = do
234                     let
235                         path = modulePath context
236                     case loop path of
237                         [] -> return ()
238                         l  -> Left $ CheckFailure [ModuleInstLoop (reverse $ name:l)]
239                         where
240                             loop [] = []
241                             loop path@(p:ps)
242                                 | name `elem` path = p:(loop ps)
243                                 | otherwise = []
244
245
246 instance NetTransformable AST.PortMap PortMap where
247     transform context (AST.MultiPortMap for) = do
248         ts <- transform context for
249         return $ concat (ts :: [PortMap])
250     transform context ast = do
251         let
252             mappedId = AST.mappedId ast
253             mappedPort = AST.mappedPort ast
254         netMappedId <- transform context mappedId
255         netMappedPort <- transform context mappedPort
256         return [(NetAST.name netMappedPort, netMappedId)]
257
258 instance NetTransformable AST.ModuleArg Integer where
259     transform _ (AST.AddressArg value) = return value
260     transform _ (AST.NaturalArg value) = return value
261     transform context (AST.ParamArg name) = return $ getParamValue context name
262
263 instance NetTransformable AST.Identifier NetAST.NodeId where
264     transform context ast = do
265         let
266             namespace = curNamespace context
267             name = identName ast
268         return NetAST.NodeId
269             { NetAST.namespace = namespace
270             , NetAST.name      = name
271             }
272             where
273                 identName (AST.SimpleIdent name) = name
274                 identName ident =
275                     let
276                         prefix = AST.prefix ident
277                         varName = AST.varName ident
278                         suffix = AST.suffix ident
279                         varValue = show $ getVarValue context varName
280                         suffixName = case suffix of
281                             Nothing -> ""
282                             Just s  -> identName s
283                     in prefix ++ varValue ++ suffixName
284
285 instance NetTransformable AST.NodeDecl NetList where
286     transform context (AST.MultiNodeDecl for) = do
287         ts <- transform context for
288         return $ concat (ts :: [NetList])
289     transform context ast = do
290         let
291             ident = AST.nodeId ast
292             nodeSpec = AST.nodeSpec ast
293         nodeId <- transform context ident
294         netNodeSpec <- transform context nodeSpec
295         return [(nodeId, netNodeSpec)]
296
297 instance NetTransformable AST.NodeSpec NetAST.NodeSpec where
298     transform context ast = do
299         let
300             nodeType = AST.nodeType ast
301             accept = AST.accept ast
302             translate = AST.translate ast
303             reserved = AST.reserved ast
304             overlay = AST.overlay ast
305         netNodeType <- transform context nodeType
306         netAccept <- transform context accept
307         netTranslate <- transform context translate
308         netReserved <- transform context reserved
309         let
310             mapBlocks = map NetAST.srcBlock netTranslate
311             nodeContext = context
312                 { mappedBlocks = netAccept ++ mapBlocks ++ netReserved }
313         netOverlay <- case overlay of
314                 Nothing -> return []
315                 Just o  -> transform nodeContext o
316         return NetAST.NodeSpec
317             { NetAST.nodeType  = netNodeType
318             , NetAST.accept    = netAccept
319             , NetAST.translate = netTranslate ++ netOverlay
320             }
321
322 instance NetTransformable AST.NodeType NetAST.NodeType where
323     transform _ AST.Memory = return NetAST.Memory
324     transform _ AST.Device = return NetAST.Device
325     transform _ AST.Other  = return NetAST.Other
326
327 instance NetTransformable AST.BlockSpec NetAST.BlockSpec where
328     transform context (AST.SingletonBlock address) = do
329         netAddress <- transform context address
330         return NetAST.BlockSpec
331             { NetAST.base  = netAddress
332             , NetAST.limit = netAddress
333             }
334     transform context (AST.RangeBlock base limit) = do
335         netBase <- transform context base
336         netLimit <- transform context limit
337         return NetAST.BlockSpec
338             { NetAST.base  = netBase
339             , NetAST.limit = netLimit
340             }
341     transform context (AST.LengthBlock base bits) = do
342         netBase <- transform context base
343         let
344             baseAddress = NetAST.address netBase
345             limit = baseAddress + 2^bits - 1
346             netLimit = NetAST.Address limit
347         return NetAST.BlockSpec
348             { NetAST.base  = netBase
349             , NetAST.limit = netLimit
350             }
351
352 instance NetTransformable AST.MapSpec NetAST.MapSpec where
353     transform context ast = do
354         let
355             block = AST.block ast
356             destNode = AST.destNode ast
357             destBase = fromMaybe (AST.LiteralAddress 0) (AST.destBase ast)
358         netBlock <- transform context block
359         netDestNode <- transform context destNode
360         netDestBase <- transform context destBase
361         return NetAST.MapSpec
362             { NetAST.srcBlock = netBlock
363             , NetAST.destNode = netDestNode
364             , NetAST.destBase = netDestBase
365             }
366
367 instance NetTransformable AST.OverlaySpec [NetAST.MapSpec] where
368     transform context ast = do
369         let
370             over = AST.over ast
371             width = AST.width ast
372             blocks = mappedBlocks context
373         netOver <- transform context over
374         let
375             maps = overlayMaps netOver width blocks
376         return maps
377
378 overlayMaps :: NetAST.NodeId -> Integer ->[NetAST.BlockSpec] -> [NetAST.MapSpec]
379 overlayMaps destId width blocks =
380     let
381         blockPoints = concat $ map toScanPoints blocks
382         maxAddress = 2^width
383         overStop  = BlockStart $ maxAddress
384         scanPoints = filter ((maxAddress >=) . address) $ sort (overStop:blockPoints)
385         startState = ScanLineState
386             { insideBlocks    = 0
387             , startAddress    = 0
388             }
389     in evalState (scanLine scanPoints []) startState
390     where
391         toScanPoints (NetAST.BlockSpec base limit) =
392                 [ BlockStart $ NetAST.address base
393                 , BlockEnd   $ NetAST.address limit
394                 ]
395         scanLine [] ms = return ms
396         scanLine (p:ps) ms = do
397             maps <- pointAction p ms
398             scanLine ps maps
399         pointAction (BlockStart a) ms = do
400             s <- get       
401             let
402                 i = insideBlocks s
403                 base = startAddress s
404                 limit = a - 1
405             maps <- if (i == 0) && (base <= limit)
406                 then
407                     let
408                         baseAddress = NetAST.Address $ startAddress s
409                         limitAddress = NetAST.Address $ a - 1
410                         srcBlock = NetAST.BlockSpec baseAddress limitAddress
411                         m = NetAST.MapSpec srcBlock destId baseAddress
412                     in return $ m:ms
413                 else return ms
414             modify (\s -> s { insideBlocks = i + 1})
415             return maps
416         pointAction (BlockEnd a) ms = do
417             s <- get
418             let
419                 i = insideBlocks s
420             put $ ScanLineState (i - 1) (a + 1)
421             return ms
422
423 data StoppingPoint
424     = BlockStart { address :: !Integer }
425     | BlockEnd   { address :: !Integer }
426     deriving (Eq, Show)
427
428 instance Ord StoppingPoint where
429     (<=) (BlockStart a1) (BlockEnd   a2)
430         | a1 == a2 = True
431         | otherwise = a1 <= a2
432     (<=) (BlockEnd   a1) (BlockStart a2)
433         | a1 == a2 = False
434         | otherwise = a1 <= a2
435     (<=) sp1 sp2 = (address sp1) <= (address sp2)
436
437 data ScanLineState
438     = ScanLineState
439         { insideBlocks :: !Integer
440         , startAddress :: !Integer
441         } deriving (Show)
442
443 instance NetTransformable AST.Address NetAST.Address where
444     transform _ (AST.LiteralAddress value) = do
445         return $ NetAST.Address value
446     transform context (AST.ParamAddress name) = do
447         let
448             value = getParamValue context name
449         return $ NetAST.Address value
450
451 instance NetTransformable a b => NetTransformable (AST.For a) [b] where
452     transform context ast = do
453         let
454             body = AST.body ast
455             varRanges = AST.varRanges ast
456         concreteRanges <- transform context varRanges
457         let
458             valueList = Map.foldWithKey iterations [] concreteRanges
459             iterContexts = map iterationContext valueList
460             ts = map (\c -> transform c body) iterContexts
461             fs = lefts ts
462             bs = rights ts
463         case fs of
464             [] -> return $ bs
465             _  -> Left $ CheckFailure (concat $ map failures fs)
466         where
467             iterations k vs [] = [Map.fromList [(k,v)] | v <- vs]
468             iterations k vs ms = concat $ map (f ms k) vs
469                 where
470                     f ms k v = map (Map.insert k v) ms
471             iterationContext varMap =
472                 let
473                     values = varValues context
474                 in context
475                     { varValues = values `Map.union` varMap }
476
477 instance NetTransformable AST.ForRange [Integer] where
478     transform context ast = do
479         let
480             start = AST.start ast
481             end = AST.end ast
482         startVal <- transform context start
483         endVal <- transform context end
484         return [startVal..endVal]
485
486 instance NetTransformable AST.ForLimit Integer where
487     transform _ (AST.LiteralLimit value) = return value
488     transform context (AST.ParamLimit name) = return $ getParamValue context name
489
490 instance NetTransformable a b => NetTransformable [a] [b] where
491     transform context ast = do
492         let
493             ts = map (transform context) ast
494             fs = lefts ts
495             bs = rights ts
496         case fs of
497             [] -> return bs
498             _  -> Left $ CheckFailure (concat $ map failures fs)
499
500 instance (Ord k, NetTransformable a b) => NetTransformable (Map k a) (Map k b) where
501     transform context ast = do
502         let
503             ks = Map.keys ast
504             es = Map.elems ast
505         ts <- transform context es
506         return $ Map.fromList (zip ks ts)
507
508 --
509 -- Checks
510 --
511 class NetCheckable a where
512     check :: Set NetAST.NodeId -> a -> Either CheckFailure ()
513
514 instance NetCheckable NetAST.NetSpec where
515     check _ (NetAST.NetSpec net) = do
516         let
517             specContext = Map.keysSet net
518         check specContext $ Map.elems net
519
520 instance NetCheckable NetAST.NodeSpec where
521     check context net = do
522         let
523             translate = NetAST.translate net
524         check context translate
525
526 instance NetCheckable NetAST.MapSpec where
527     check context net = do
528         let
529            destNode = NetAST.destNode net
530         check context destNode
531
532 instance NetCheckable NetAST.NodeId where
533     check context net = do
534         if net `Set.member` context
535             then return ()
536             else Left $ CheckFailure [UndefinedReference $ show net]
537
538 instance NetCheckable a => NetCheckable [a] where
539     check context net = do
540         let
541             checked = map (check context) net
542             fs = lefts $ checked
543         case fs of
544             [] -> return ()
545             _  -> Left $ CheckFailure (concat $ map failures fs)
546
547 getModule :: Context -> String -> AST.Module
548 getModule context name =
549     let
550         modules = AST.modules $ spec context
551     in modules Map.! name
552
553 getParamValue :: Context -> String -> Integer
554 getParamValue context name =
555     let
556         params = paramValues context
557     in params Map.! name
558
559 getVarValue :: Context -> String -> Integer
560 getVarValue context name =
561     let
562         vars = varValues context
563     in vars Map.! name
564
565 checkDuplicates :: (Eq a, Show a) => [a] -> (String -> FailedCheck) -> Either CheckFailure ()
566 checkDuplicates nodeIds fail = do
567     let
568         duplicates = duplicateNames nodeIds
569     case duplicates of
570         [] -> return ()
571         _  -> Left $ CheckFailure (map (fail . show) duplicates)
572     where
573         duplicateNames [] = []
574         duplicateNames (x:xs)
575             | x `elem` xs = nub $ [x] ++ duplicateNames xs
576             | otherwise = duplicateNames xs