Sockeye: clean up
[barrelfish] / tools / sockeye / SockeyeNetBuilder.hs
1 {-
2     SockeyeNetBuilder.hs: Decoding net builder for Sockeye
3
4     Part of Sockeye
5
6     Copyright (c) 2017, ETH Zurich.
7
8     All rights reserved.
9
10     This file is distributed under the terms in the attached LICENSE file.
11     If you do not find this file, copies can be found by writing to:
12     ETH Zurich D-INFK, CAB F.78, Universitaetstr. 6, CH-8092 Zurich,
13     Attn: Systems Group.
14 -}
15
16 {-# LANGUAGE MultiParamTypeClasses #-}
17 {-# LANGUAGE FlexibleInstances #-}
18 {-# LANGUAGE FlexibleContexts #-}
19
20 module SockeyeNetBuilder
21 ( sockeyeBuildNet ) where
22
23 import Control.Monad.State
24
25 import Data.Either
26 import Data.List (nub, intercalate, sort)
27 import Data.Map (Map)
28 import qualified Data.Map as Map
29 import Data.Maybe (fromMaybe)
30 import Data.Set (Set)
31 import qualified Data.Set as Set
32
33 import qualified SockeyeAST as AST
34 import qualified SockeyeASTDecodingNet as NetAST
35
36 type NetNodeDecl = (NetAST.NodeId, NetAST.NodeSpec)
37 type NetList = [NetNodeDecl]
38 type PortMap = [(String, NetAST.NodeId)]
39
40 data FailedCheck
41     = ModuleInstLoop [String]
42     | DuplicateInPort !String !String
43     | DuplicateInMap !String !String
44     | UndefinedInPort !String !String
45     | DuplicateOutPort !String !String
46     | DuplicateOutMap !String !String
47     | UndefinedOutPort !String !String
48     | DuplicateIdentifer !String
49     | UndefinedReference !String
50
51 instance Show FailedCheck where
52     show (ModuleInstLoop loop) = concat ["Module instantiation loop:'", intercalate "' -> '" loop, "'"]
53     show (DuplicateInPort  modName port) = concat ["Multiple declarations of input port '", port, "' in '", modName, "'"]
54     show (DuplicateInMap   ns      port) = concat ["Multiple mappings for input port '", port, "' in '", ns, "'"]
55     show (UndefinedInPort  modName port) = concat ["'", port, "' is not an input port in '", modName, "'"]
56     show (DuplicateOutPort modName port) = concat ["Multiple declarations of output port '", port, "' in '", modName, "'"]
57     show (DuplicateOutMap   ns      port) = concat ["Multiple mappings for output port '", port, "' in '", ns, "'"]
58     show (UndefinedOutPort modName port) = concat ["'", port, "' is not an output port in '", modName, "'"]
59     show (DuplicateIdentifer ident)   = concat ["Multiple declarations of node '", show ident, "'"]
60     show (UndefinedReference ident)   = concat ["Reference to undefined node '", show ident, "'"]
61
62 newtype CheckFailure = CheckFailure
63     { failures :: [FailedCheck] }
64
65 instance Show CheckFailure where
66     show (CheckFailure fs) = unlines $ "":(map show fs)
67
68 data Context = Context
69     { spec         :: AST.SockeyeSpec
70     , modulePath   :: [String]
71     , curNamespace :: NetAST.Namespace
72     , paramValues  :: Map String Integer
73     , varValues    :: Map String Integer
74     , inPortMaps   :: Map String NetAST.NodeId
75     , outPortMaps  :: Map String NetAST.NodeId
76     , mappedBlocks :: [NetAST.BlockSpec]
77     }
78
79 sockeyeBuildNet :: AST.SockeyeSpec -> Either CheckFailure NetAST.NetSpec
80 sockeyeBuildNet ast = do
81     let
82         context = Context
83             { spec         = AST.SockeyeSpec Map.empty
84             , modulePath   = []
85             , curNamespace = NetAST.Namespace []
86             , paramValues  = Map.empty
87             , varValues    = Map.empty
88             , inPortMaps   = Map.empty
89             , outPortMaps  = Map.empty
90             , mappedBlocks = []
91             }        
92     net <- transform context ast
93     check Set.empty net
94     return net
95 --            
96 -- Build net
97 --
98 class NetTransformable a b where
99     transform :: Context -> a -> Either CheckFailure b
100
101 instance NetTransformable AST.SockeyeSpec NetAST.NetSpec where
102     transform context ast = do
103         let
104             rootInst = AST.ModuleInst
105                 { AST.namespace  = AST.SimpleIdent ""
106                 , AST.moduleName = "@root"
107                 , AST.arguments  = Map.empty
108                 , AST.inPortMap  = []
109                 , AST.outPortMap = []
110                 }
111             specContext = context
112                 { spec = ast }
113         netList <- transform specContext rootInst
114         let
115             nodeIds = map fst netList
116         checkDuplicates nodeIds DuplicateIdentifer
117         let
118             nodeMap = Map.fromList netList
119         return $ NetAST.NetSpec nodeMap
120
121 instance NetTransformable AST.Module NetList where
122     transform context ast = do
123         let
124             inPorts = AST.inputPorts ast
125             outPorts = AST.outputPorts ast
126             nodeDecls = AST.nodeDecls ast
127             moduleInsts = AST.moduleInsts ast
128         inDecls <- do
129             net <- transform context inPorts
130             return $ concat (net :: [NetList])
131         outDecls <- do
132             net <- transform context outPorts
133             return $ concat (net :: [NetList])
134         -- TODO check duplicate ports
135         -- TODO check mappings to non existing port
136         netDecls <- transform context nodeDecls
137         netInsts <- transform context moduleInsts
138         return $ concat (inDecls:outDecls:netDecls ++ netInsts)            
139
140 instance NetTransformable AST.Port NetList where
141     transform context (AST.MultiPort for) = do
142         netPorts <- transform context for
143         return $ concat (netPorts :: [NetList])
144     transform context (AST.InputPort portId portWidth) = do
145         netPortId <- transform context portId
146         let
147             portMap = inPortMaps context
148             name = NetAST.name netPortId
149             mappedId = Map.lookup name portMap
150         case mappedId of
151             Nothing    -> return []
152             Just ident -> do
153                 let
154                     node = portNode netPortId portWidth
155                 return [(ident, node)]
156     transform context (AST.OutputPort portId portWidth) = do
157         netPortId <- transform context portId
158         let
159             portMap = outPortMaps context
160             name = NetAST.name netPortId
161             mappedId = Map.lookup name portMap
162         case mappedId of
163             Nothing    -> return [(netPortId, portNodeTemplate)]
164             Just ident -> do
165                 let
166                     node = portNode ident portWidth
167                 return [(netPortId, node)]
168
169 portNode :: NetAST.NodeId -> Integer -> NetAST.NodeSpec
170 portNode destNode width =
171     let
172         base = NetAST.Address 0
173         limit = NetAST.Address $ 2^width - 1
174         srcBlock = NetAST.BlockSpec
175             { NetAST.base  = base
176             , NetAST.limit = limit
177             }
178         map = NetAST.MapSpec
179                 { NetAST.srcBlock = srcBlock
180                 , NetAST.destNode = destNode
181                 , NetAST.destBase = base
182                 }
183     in portNodeTemplate { NetAST.translate = [map] }
184
185 portNodeTemplate :: NetAST.NodeSpec
186 portNodeTemplate = NetAST.NodeSpec
187     { NetAST.nodeType  = NetAST.Other
188     , NetAST.accept    = []
189     , NetAST.translate = []
190     }    
191
192 instance NetTransformable AST.ModuleInst NetList where
193     transform context (AST.MultiModuleInst for) = do
194         net <- transform context for
195         return $ concat (net :: [NetList])
196     transform context ast = do
197         let
198             namespace = AST.namespace ast
199             name = AST.moduleName ast
200             args = AST.arguments ast
201             inPortMap = AST.inPortMap ast
202             outPortMap = AST.outPortMap ast
203             mod = getModule context name
204         checkSelfInst name
205         netNamespace <- transform context namespace
206         netArgs <- transform context args
207         netInMap <- transform context inPortMap
208         netOutMap <- transform context outPortMap
209         let
210             inMaps = concat (netInMap :: [PortMap])
211             outMaps = concat (netOutMap :: [PortMap])
212         checkDuplicates (map fst inMaps) (DuplicateInMap $ show netNamespace) 
213         checkDuplicates (map fst outMaps) (DuplicateOutMap $ show netNamespace)
214         let
215             modContext = moduleContext name netNamespace netArgs inMaps outMaps
216         transform modContext mod
217             where
218                 moduleContext name namespace args inMaps outMaps =
219                     let
220                         path = modulePath context
221                         base = NetAST.ns $ NetAST.namespace namespace
222                         newNs = case NetAST.name namespace of
223                             "" -> NetAST.Namespace base
224                             n  -> NetAST.Namespace $ n:base
225                     in context
226                         { modulePath   = name:path
227                         , curNamespace = newNs
228                         , paramValues  = args
229                         , varValues    = Map.empty
230                         , inPortMaps   = Map.fromList inMaps
231                         , outPortMaps  = Map.fromList outMaps
232                         }
233                 checkSelfInst name = do
234                     let
235                         path = modulePath context
236                     case loop path of
237                         [] -> return ()
238                         l  -> Left $ CheckFailure [ModuleInstLoop (reverse $ name:l)]
239                         where
240                             loop [] = []
241                             loop path@(p:ps)
242                                 | name `elem` path = p:(loop ps)
243                                 | otherwise = []
244
245
246 instance NetTransformable AST.PortMap PortMap where
247     transform context (AST.MultiPortMap for) = do
248         ts <- transform context for
249         return $ concat (ts :: [PortMap])
250     transform context ast = do
251         let
252             mappedId = AST.mappedId ast
253             mappedPort = AST.mappedPort ast
254         netMappedId <- transform context mappedId
255         netMappedPort <- transform context mappedPort
256         return [(NetAST.name netMappedPort, netMappedId)]
257
258 instance NetTransformable AST.ModuleArg Integer where
259     transform _ (AST.AddressArg value) = return value
260     transform _ (AST.NaturalArg value) = return value
261     transform context (AST.ParamArg name) = return $ getParamValue context name
262
263 instance NetTransformable AST.Identifier NetAST.NodeId where
264     transform context ast = do
265         let
266             namespace = curNamespace context
267             name = identName ast
268         return NetAST.NodeId
269             { NetAST.namespace = namespace
270             , NetAST.name      = name
271             }
272             where
273                 identName (AST.SimpleIdent name) = name
274                 identName ident =
275                     let
276                         prefix = AST.prefix ident
277                         varName = AST.varName ident
278                         suffix = AST.suffix ident
279                         varValue = show $ getVarValue context varName
280                         suffixName = case suffix of
281                             Nothing -> ""
282                             Just s  -> identName s
283                     in prefix ++ varValue ++ suffixName
284
285 instance NetTransformable AST.NodeDecl NetList where
286     transform context (AST.MultiNodeDecl for) = do
287         ts <- transform context for
288         return $ concat (ts :: [NetList])
289     transform context ast = do
290         let
291             ident = AST.nodeId ast
292             nodeSpec = AST.nodeSpec ast
293         nodeId <- transform context ident
294         netNodeSpec <- transform context nodeSpec
295         return [(nodeId, netNodeSpec)]
296
297 instance NetTransformable AST.NodeSpec NetAST.NodeSpec where
298     transform context ast = do
299         let
300             nodeType = AST.nodeType ast
301             accept = AST.accept ast
302             translate = AST.translate ast
303             overlay = AST.overlay ast
304         netNodeType <- maybe (return NetAST.Other) (transform context) nodeType
305         netAccept <- transform context accept
306         netTranslate <- transform context translate
307         let
308             mapBlocks = map NetAST.srcBlock netTranslate
309             nodeContext = context
310                 { mappedBlocks = netAccept ++ mapBlocks }
311         netOverlay <- case overlay of
312                 Nothing -> return []
313                 Just o  -> transform nodeContext o
314         return NetAST.NodeSpec
315             { NetAST.nodeType  = netNodeType
316             , NetAST.accept    = netAccept
317             , NetAST.translate = netTranslate ++ netOverlay
318             }
319
320 instance NetTransformable AST.NodeType NetAST.NodeType where
321     transform _ AST.Memory = return NetAST.Memory
322     transform _ AST.Device = return NetAST.Device
323
324 instance NetTransformable AST.BlockSpec NetAST.BlockSpec where
325     transform context (AST.SingletonBlock address) = do
326         netAddress <- transform context address
327         return NetAST.BlockSpec
328             { NetAST.base  = netAddress
329             , NetAST.limit = netAddress
330             }
331     transform context (AST.RangeBlock base limit) = do
332         netBase <- transform context base
333         netLimit <- transform context limit
334         return NetAST.BlockSpec
335             { NetAST.base  = netBase
336             , NetAST.limit = netLimit
337             }
338     transform context (AST.LengthBlock base bits) = do
339         netBase <- transform context base
340         let
341             baseAddress = NetAST.address netBase
342             limit = baseAddress + 2^bits - 1
343             netLimit = NetAST.Address limit
344         return NetAST.BlockSpec
345             { NetAST.base  = netBase
346             , NetAST.limit = netLimit
347             }
348
349 instance NetTransformable AST.MapSpec NetAST.MapSpec where
350     transform context ast = do
351         let
352             block = AST.block ast
353             destNode = AST.destNode ast
354             destBase = fromMaybe (AST.LiteralAddress 0) (AST.destBase ast)
355         netBlock <- transform context block
356         netDestNode <- transform context destNode
357         netDestBase <- transform context destBase
358         return NetAST.MapSpec
359             { NetAST.srcBlock = netBlock
360             , NetAST.destNode = netDestNode
361             , NetAST.destBase = netDestBase
362             }
363
364 instance NetTransformable AST.OverlaySpec [NetAST.MapSpec] where
365     transform context ast = do
366         let
367             over = AST.over ast
368             width = AST.width ast
369             blocks = mappedBlocks context
370         netOver <- transform context over
371         let
372             maps = overlayMaps netOver width blocks
373         return maps
374
375 overlayMaps :: NetAST.NodeId -> Integer ->[NetAST.BlockSpec] -> [NetAST.MapSpec]
376 overlayMaps destId width blocks =
377     let
378         blockPoints = concat $ map toScanPoints blocks
379         maxAddress = 2^width
380         overStop  = BlockStart $ maxAddress
381         scanPoints = filter ((maxAddress >=) . address) $ sort (overStop:blockPoints)
382         startState = ScanLineState
383             { insideBlocks    = 0
384             , startAddress    = 0
385             }
386     in evalState (scanLine scanPoints []) startState
387     where
388         toScanPoints (NetAST.BlockSpec base limit) =
389                 [ BlockStart $ NetAST.address base
390                 , BlockEnd   $ NetAST.address limit
391                 ]
392         scanLine [] ms = return ms
393         scanLine (p:ps) ms = do
394             maps <- pointAction p ms
395             scanLine ps maps
396         pointAction (BlockStart a) ms = do
397             s <- get       
398             let
399                 i = insideBlocks s
400                 base = startAddress s
401                 limit = a - 1
402             maps <- if (i == 0) && (base <= limit)
403                 then
404                     let
405                         baseAddress = NetAST.Address $ startAddress s
406                         limitAddress = NetAST.Address $ a - 1
407                         srcBlock = NetAST.BlockSpec baseAddress limitAddress
408                         m = NetAST.MapSpec srcBlock destId baseAddress
409                     in return $ m:ms
410                 else return ms
411             modify (\s -> s { insideBlocks = i + 1})
412             return maps
413         pointAction (BlockEnd a) ms = do
414             s <- get
415             let
416                 i = insideBlocks s
417             put $ ScanLineState (i - 1) (a + 1)
418             return ms
419
420 data StoppingPoint
421     = BlockStart { address :: !Integer }
422     | BlockEnd   { address :: !Integer }
423     deriving (Eq, Show)
424
425 instance Ord StoppingPoint where
426     (<=) (BlockStart a1) (BlockEnd   a2)
427         | a1 == a2 = True
428         | otherwise = a1 <= a2
429     (<=) (BlockEnd   a1) (BlockStart a2)
430         | a1 == a2 = False
431         | otherwise = a1 <= a2
432     (<=) sp1 sp2 = (address sp1) <= (address sp2)
433
434 data ScanLineState
435     = ScanLineState
436         { insideBlocks :: !Integer
437         , startAddress :: !Integer
438         } deriving (Show)
439
440 instance NetTransformable AST.Address NetAST.Address where
441     transform _ (AST.LiteralAddress value) = do
442         return $ NetAST.Address value
443     transform context (AST.ParamAddress name) = do
444         let
445             value = getParamValue context name
446         return $ NetAST.Address value
447
448 instance NetTransformable a b => NetTransformable (AST.For a) [b] where
449     transform context ast = do
450         let
451             body = AST.body ast
452             varRanges = AST.varRanges ast
453         concreteRanges <- transform context varRanges
454         let
455             valueList = Map.foldWithKey iterations [] concreteRanges
456             iterContexts = map iterationContext valueList
457             ts = map (\c -> transform c body) iterContexts
458             fs = lefts ts
459             bs = rights ts
460         case fs of
461             [] -> return $ bs
462             _  -> Left $ CheckFailure (concat $ map failures fs)
463         where
464             iterations k vs [] = [Map.fromList [(k,v)] | v <- vs]
465             iterations k vs ms = concat $ map (f ms k) vs
466                 where
467                     f ms k v = map (Map.insert k v) ms
468             iterationContext varMap =
469                 let
470                     values = varValues context
471                 in context
472                     { varValues = values `Map.union` varMap }
473
474 instance NetTransformable AST.ForRange [Integer] where
475     transform context ast = do
476         let
477             start = AST.start ast
478             end = AST.end ast
479         startVal <- transform context start
480         endVal <- transform context end
481         return [startVal..endVal]
482
483 instance NetTransformable AST.ForLimit Integer where
484     transform _ (AST.LiteralLimit value) = return value
485     transform context (AST.ParamLimit name) = return $ getParamValue context name
486
487 instance NetTransformable a b => NetTransformable [a] [b] where
488     transform context ast = do
489         let
490             ts = map (transform context) ast
491             fs = lefts ts
492             bs = rights ts
493         case fs of
494             [] -> return bs
495             _  -> Left $ CheckFailure (concat $ map failures fs)
496
497 instance (Ord k, NetTransformable a b) => NetTransformable (Map k a) (Map k b) where
498     transform context ast = do
499         let
500             ks = Map.keys ast
501             es = Map.elems ast
502         ts <- transform context es
503         return $ Map.fromList (zip ks ts)
504
505 --
506 -- Checks
507 --
508 class NetCheckable a where
509     check :: Set NetAST.NodeId -> a -> Either CheckFailure ()
510
511 instance NetCheckable NetAST.NetSpec where
512     check _ (NetAST.NetSpec net) = do
513         let
514             specContext = Map.keysSet net
515         check specContext $ Map.elems net
516
517 instance NetCheckable NetAST.NodeSpec where
518     check context net = do
519         let
520             translate = NetAST.translate net
521         check context translate
522
523 instance NetCheckable NetAST.MapSpec where
524     check context net = do
525         let
526            destNode = NetAST.destNode net
527         check context destNode
528
529 instance NetCheckable NetAST.NodeId where
530     check context net = do
531         if net `Set.member` context
532             then return ()
533             else Left $ CheckFailure [UndefinedReference $ show net]
534
535 instance NetCheckable a => NetCheckable [a] where
536     check context net = do
537         let
538             checked = map (check context) net
539             fs = lefts $ checked
540         case fs of
541             [] -> return ()
542             _  -> Left $ CheckFailure (concat $ map failures fs)
543
544 getModule :: Context -> String -> AST.Module
545 getModule context name =
546     let
547         modules = AST.modules $ spec context
548     in modules Map.! name
549
550 getParamValue :: Context -> String -> Integer
551 getParamValue context name =
552     let
553         params = paramValues context
554     in params Map.! name
555
556 getVarValue :: Context -> String -> Integer
557 getVarValue context name =
558     let
559         vars = varValues context
560     in vars Map.! name
561
562 checkDuplicates :: (Eq a, Show a) => [a] -> (String -> FailedCheck) -> Either CheckFailure ()
563 checkDuplicates nodeIds fail = do
564     let
565         duplicates = duplicateNames nodeIds
566     case duplicates of
567         [] -> return ()
568         _  -> Left $ CheckFailure (map (fail . show) duplicates)
569     where
570         duplicateNames [] = []
571         duplicateNames (x:xs)
572             | x `elem` xs = nub $ [x] ++ duplicateNames xs
573             | otherwise = duplicateNames xs