Sockeye: Handle arbitrary large numbers
[barrelfish] / tools / sockeye / SockeyeNetBuilder.hs
1 {-
2     SockeyeNetBuilder.hs: Decoding net builder for Sockeye
3
4     Part of Sockeye
5
6     Copyright (c) 2017, ETH Zurich.
7
8     All rights reserved.
9
10     This file is distributed under the terms in the attached LICENSE file.
11     If you do not find this file, copies can be found by writing to:
12     ETH Zurich D-INFK, CAB F.78, Universitaetstr. 6, CH-8092 Zurich,
13     Attn: Systems Group.
14 -}
15
16 {-# LANGUAGE MultiParamTypeClasses #-}
17 {-# LANGUAGE FlexibleInstances #-}
18 {-# LANGUAGE FlexibleContexts #-}
19
20 module SockeyeNetBuilder
21 ( sockeyeBuildNet ) where
22
23 import Control.Monad.State
24
25 import Data.Either
26 import Data.List (nub, intercalate, sort)
27 import Data.Map (Map)
28 import qualified Data.Map as Map
29 import Data.Maybe (catMaybes, fromMaybe, maybe)
30 import Data.Set (Set)
31 import qualified Data.Set as Set
32
33 import Numeric (showHex)
34
35 import qualified SockeyeAST as AST
36 import qualified SockeyeASTDecodingNet as NetAST
37
38 import Debug.Trace
39
40 type NetNodeDecl = (NetAST.NodeId, NetAST.NodeSpec)
41 type NetList = [NetNodeDecl]
42 type PortList = [NetAST.NodeId]
43 type PortMap = [(String, NetAST.NodeId)]
44
45 data FailedCheck
46     = ModuleInstLoop [String]
47     | DuplicateInPort !String !String
48     | DuplicateInMap !String !String
49     | UndefinedInPort !String !String
50     | DuplicateOutPort !String !String
51     | DuplicateOutMap !String !String
52     | UndefinedOutPort !String !String
53     | DuplicateIdentifer !String
54     | UndefinedReference !String
55
56 instance Show FailedCheck where
57     show (ModuleInstLoop loop) = concat ["Module instantiation loop:'", intercalate "' -> '" loop, "'"]
58     show (DuplicateInPort  modName port) = concat ["Multiple declarations of input port '", port, "' in '", modName, "'"]
59     show (DuplicateInMap   ns      port) = concat ["Multiple mappings for input port '", port, "' in '", ns, "'"]
60     show (UndefinedInPort  modName port) = concat ["'", port, "' is not an input port in '", modName, "'"]
61     show (DuplicateOutPort modName port) = concat ["Multiple declarations of output port '", port, "' in '", modName, "'"]
62     show (DuplicateOutMap   ns      port) = concat ["Multiple mappings for output port '", port, "' in '", ns, "'"]
63     show (UndefinedOutPort modName port) = concat ["'", port, "' is not an output port in '", modName, "'"]
64     show (DuplicateIdentifer ident)   = concat ["Multiple declarations of node '", show ident, "'"]
65     show (UndefinedReference ident)   = concat ["Reference to undefined node '", show ident, "'"]
66
67 newtype CheckFailure = CheckFailure
68     { failures :: [FailedCheck] }
69
70 instance Show CheckFailure where
71     show (CheckFailure fs) = unlines $ "":(map show fs)
72
73 data Context = Context
74     { spec         :: AST.SockeyeSpec
75     , modulePath   :: [String]
76     , curNamespace :: NetAST.Namespace
77     , paramValues  :: Map String Integer
78     , varValues    :: Map String Integer
79     , inPortMaps   :: Map String NetAST.NodeId
80     , outPortMaps  :: Map String NetAST.NodeId
81     , mappedBlocks :: [NetAST.BlockSpec]
82     }
83
84 sockeyeBuildNet :: AST.SockeyeSpec -> Either CheckFailure NetAST.NetSpec
85 sockeyeBuildNet ast = do
86     let
87         context = Context
88             { spec         = AST.SockeyeSpec Map.empty
89             , modulePath   = []
90             , curNamespace = NetAST.Namespace []
91             , paramValues  = Map.empty
92             , varValues    = Map.empty
93             , inPortMaps   = Map.empty
94             , outPortMaps  = Map.empty
95             , mappedBlocks = []
96             }        
97     net <- transform context ast
98     check Set.empty net
99     return net
100 --            
101 -- Build net
102 --
103 class NetTransformable a b where
104     transform :: Context -> a -> Either CheckFailure b
105
106 instance NetTransformable AST.SockeyeSpec NetAST.NetSpec where
107     transform context ast = do
108         let
109             rootInst = AST.ModuleInst
110                 { AST.namespace  = AST.SimpleIdent ""
111                 , AST.moduleName = "@root"
112                 , AST.arguments  = Map.empty
113                 , AST.inPortMap  = []
114                 , AST.outPortMap = []
115                 }
116             specContext = context
117                 { spec = ast }
118         netList <- transform specContext rootInst
119         let
120             nodeIds = map fst netList
121         checkDuplicates nodeIds DuplicateIdentifer
122         let
123             nodeMap = Map.fromList netList
124         return $ NetAST.NetSpec nodeMap
125
126 instance NetTransformable AST.Module NetList where
127     transform context ast = do
128         let
129             inPorts = AST.inputPorts ast
130             outPorts = AST.outputPorts ast
131             nodeDecls = AST.nodeDecls ast
132             moduleInsts = AST.moduleInsts ast
133         inDecls <- do
134             net <- transform context inPorts
135             return $ concat (net :: [NetList])
136         outDecls <- do
137             net <- transform context outPorts
138             return $ concat (net :: [NetList])
139         -- TODO check duplicate ports
140         -- TODO check mappings to non existing port
141         netDecls <- transform context nodeDecls
142         netInsts <- transform context moduleInsts
143         return $ concat (inDecls:outDecls:netDecls ++ netInsts)            
144
145 instance NetTransformable AST.Port NetList where
146     transform context (AST.MultiPort for) = do
147         netPorts <- transform context for
148         return $ concat (netPorts :: [NetList])
149     transform context (AST.InputPort portId portWidth) = do
150         netPortId <- transform context portId
151         let
152             portMap = inPortMaps context
153             name = NetAST.name netPortId
154             mappedId = Map.lookup name portMap
155         case mappedId of
156             Nothing    -> return []
157             Just ident -> do
158                 let
159                     node = portNode netPortId portWidth
160                 return [(ident, node)]
161     transform context (AST.OutputPort portId portWidth) = do
162         netPortId <- transform context portId
163         let
164             portMap = outPortMaps context
165             name = NetAST.name netPortId
166             mappedId = Map.lookup name portMap
167         case mappedId of
168             Nothing    -> return [(netPortId, portNodeTemplate)]
169             Just ident -> do
170                 let
171                     node = portNode ident portWidth
172                 return [(netPortId, node)]
173
174 portNode :: NetAST.NodeId -> Integer -> NetAST.NodeSpec
175 portNode destNode width =
176     let
177         base = NetAST.Address 0
178         limit = NetAST.Address $ 2^width - 1
179         srcBlock = NetAST.BlockSpec
180             { NetAST.base  = base
181             , NetAST.limit = limit
182             }
183         map = NetAST.MapSpec
184                 { NetAST.srcBlock = srcBlock
185                 , NetAST.destNode = destNode
186                 , NetAST.destBase = base
187                 }
188     in portNodeTemplate { NetAST.translate = [map] }
189
190 portNodeTemplate :: NetAST.NodeSpec
191 portNodeTemplate = NetAST.NodeSpec
192     { NetAST.nodeType  = NetAST.Other
193     , NetAST.accept    = []
194     , NetAST.translate = []
195     }    
196
197 instance NetTransformable AST.ModuleInst NetList where
198     transform context (AST.MultiModuleInst for) = do
199         net <- transform context for
200         return $ concat (net :: [NetList])
201     transform context ast = do
202         let
203             namespace = AST.namespace ast
204             name = AST.moduleName ast
205             args = AST.arguments ast
206             inPortMap = AST.inPortMap ast
207             outPortMap = AST.outPortMap ast
208             mod = getModule context name
209         checkSelfInst name
210         netNamespace <- transform context namespace
211         netArgs <- transform context args
212         netInMap <- transform context inPortMap
213         netOutMap <- transform context outPortMap
214         let
215             inMaps = concat (netInMap :: [PortMap])
216             outMaps = concat (netOutMap :: [PortMap])
217         checkDuplicates (map fst inMaps) (DuplicateInMap $ show netNamespace) 
218         checkDuplicates (map fst outMaps) (DuplicateOutMap $ show netNamespace)
219         let
220             modContext = moduleContext name netNamespace netArgs inMaps outMaps
221         transform modContext mod
222             where
223                 moduleContext name namespace args inMaps outMaps =
224                     let
225                         path = modulePath context
226                         base = NetAST.ns $ NetAST.namespace namespace
227                         newNs = case NetAST.name namespace of
228                             "" -> NetAST.Namespace base
229                             n  -> NetAST.Namespace $ n:base
230                     in context
231                         { modulePath   = name:path
232                         , curNamespace = newNs
233                         , paramValues  = args
234                         , varValues    = Map.empty
235                         , inPortMaps   = Map.fromList inMaps
236                         , outPortMaps  = Map.fromList outMaps
237                         }
238                 checkSelfInst name = do
239                     let
240                         path = modulePath context
241                     case loop path of
242                         [] -> return ()
243                         l  -> Left $ CheckFailure [ModuleInstLoop (reverse $ name:l)]
244                         where
245                             loop [] = []
246                             loop path@(p:ps)
247                                 | name `elem` path = p:(loop ps)
248                                 | otherwise = []
249
250
251 instance NetTransformable AST.PortMap PortMap where
252     transform context (AST.MultiPortMap for) = do
253         ts <- transform context for
254         return $ concat (ts :: [PortMap])
255     transform context ast = do
256         let
257             mappedId = AST.mappedId ast
258             mappedPort = AST.mappedPort ast
259         netMappedId <- transform context mappedId
260         netMappedPort <- transform context mappedPort
261         return [(NetAST.name netMappedPort, netMappedId)]
262
263 instance NetTransformable AST.ModuleArg Integer where
264     transform context (AST.AddressArg value) = return value
265     transform context (AST.NaturalArg value) = return value
266     transform context (AST.ParamArg name) = return $ getParamValue context name
267
268 instance NetTransformable AST.Identifier NetAST.NodeId where
269     transform context ast = do
270         let
271             namespace = curNamespace context
272             name = identName ast
273         return NetAST.NodeId
274             { NetAST.namespace = namespace
275             , NetAST.name      = name
276             }
277             where
278                 identName (AST.SimpleIdent name) = name
279                 identName ident =
280                     let
281                         prefix = AST.prefix ident
282                         varName = AST.varName ident
283                         suffix = AST.suffix ident
284                         varValue = show $ getVarValue context varName
285                         suffixName = case suffix of
286                             Nothing -> ""
287                             Just s  -> identName s
288                     in prefix ++ varValue ++ suffixName
289
290 instance NetTransformable AST.NodeDecl NetList where
291     transform context (AST.MultiNodeDecl for) = do
292         ts <- transform context for
293         return $ concat (ts :: [NetList])
294     transform context ast = do
295         let
296             ident = AST.nodeId ast
297             nodeSpec = AST.nodeSpec ast
298         nodeId <- transform context ident
299         netNodeSpec <- transform context nodeSpec
300         return [(nodeId, netNodeSpec)]
301
302 instance NetTransformable AST.NodeSpec NetAST.NodeSpec where
303     transform context ast = do
304         let
305             nodeType = AST.nodeType ast
306             accept = AST.accept ast
307             translate = AST.translate ast
308             overlay = AST.overlay ast
309         netNodeType <- maybe (return NetAST.Other) (transform context) nodeType
310         netAccept <- transform context accept
311         netTranslate <- transform context translate
312         let
313             mapBlocks = map NetAST.srcBlock netTranslate
314             nodeContext = context
315                 { mappedBlocks = netAccept ++ mapBlocks }
316         netOverlay <- case overlay of
317                 Nothing -> return []
318                 Just o  -> transform nodeContext o
319         return NetAST.NodeSpec
320             { NetAST.nodeType  = netNodeType
321             , NetAST.accept    = netAccept
322             , NetAST.translate = netTranslate ++ netOverlay
323             }
324
325 instance NetTransformable AST.NodeType NetAST.NodeType where
326     transform _ AST.Memory = return NetAST.Memory
327     transform _ AST.Device = return NetAST.Device
328
329 instance NetTransformable AST.BlockSpec NetAST.BlockSpec where
330     transform context (AST.SingletonBlock address) = do
331         netAddress <- transform context address
332         return NetAST.BlockSpec
333             { NetAST.base  = netAddress
334             , NetAST.limit = netAddress
335             }
336     transform context (AST.RangeBlock base limit) = do
337         netBase <- transform context base
338         netLimit <- transform context limit
339         return NetAST.BlockSpec
340             { NetAST.base  = netBase
341             , NetAST.limit = netLimit
342             }
343     transform context (AST.LengthBlock base bits) = do
344         netBase <- transform context base
345         let
346             baseAddress = NetAST.address netBase
347             limit = baseAddress + 2^bits - 1
348             netLimit = NetAST.Address limit
349         return NetAST.BlockSpec
350             { NetAST.base  = netBase
351             , NetAST.limit = netLimit
352             }
353
354 instance NetTransformable AST.MapSpec NetAST.MapSpec where
355     transform context ast = do
356         let
357             block = AST.block ast
358             destNode = AST.destNode ast
359             destBase = fromMaybe (AST.LiteralAddress 0) (AST.destBase ast)
360         netBlock <- transform context block
361         netDestNode <- transform context destNode
362         netDestBase <- transform context destBase
363         return NetAST.MapSpec
364             { NetAST.srcBlock = netBlock
365             , NetAST.destNode = netDestNode
366             , NetAST.destBase = netDestBase
367             }
368
369 instance NetTransformable AST.OverlaySpec [NetAST.MapSpec] where
370     transform context ast = do
371         let
372             over = AST.over ast
373             width = AST.width ast
374             blocks = mappedBlocks context
375         netOver <- transform context over
376         let
377             maps = overlayMaps netOver width blocks
378         return maps
379
380 overlayMaps :: NetAST.NodeId -> Integer ->[NetAST.BlockSpec] -> [NetAST.MapSpec]
381 overlayMaps destId width blocks =
382     let
383         blockPoints = concat $ map toScanPoints blocks
384         maxAddress = 2^width
385         overStop  = BlockStart $ maxAddress
386         scanPoints = filter ((maxAddress >=) . address) $ sort (overStop:blockPoints)
387         startState = ScanLineState
388             { insideBlocks    = 0
389             , startAddress    = 0
390             }
391     in evalState (scanLine scanPoints []) startState
392     where
393         toScanPoints (NetAST.BlockSpec base limit) =
394                 [ BlockStart $ NetAST.address base
395                 , BlockEnd   $ NetAST.address limit
396                 ]
397         scanLine [] ms = return ms
398         scanLine (p:ps) ms = do
399             maps <- pointAction p ms
400             scanLine ps maps
401         pointAction (BlockStart a) ms = do
402             s <- get       
403             let
404                 i = insideBlocks s
405                 base = startAddress s
406                 limit = a - 1
407             maps <- if (i == 0) && (base <= limit)
408                 then
409                     let
410                         baseAddress = NetAST.Address $ startAddress s
411                         limitAddress = NetAST.Address $ a - 1
412                         srcBlock = NetAST.BlockSpec baseAddress limitAddress
413                         m = NetAST.MapSpec srcBlock destId baseAddress
414                     in return $ m:ms
415                 else return ms
416             modify (\s -> s { insideBlocks = i + 1})
417             return maps
418         pointAction (BlockEnd a) ms = do
419             s <- get
420             let
421                 i = insideBlocks s
422             put $ ScanLineState (i - 1) (a + 1)
423             return ms
424
425 data StoppingPoint
426     = BlockStart { address :: !Integer }
427     | BlockEnd   { address :: !Integer }
428     deriving (Eq, Show)
429
430 instance Ord StoppingPoint where
431     (<=) (BlockStart a1) (BlockEnd   a2)
432         | a1 == a2 = True
433         | otherwise = a1 <= a2
434     (<=) (BlockEnd   a1) (BlockStart a2)
435         | a1 == a2 = False
436         | otherwise = a1 <= a2
437     (<=) sp1 sp2 = (address sp1) <= (address sp2)
438
439 data ScanLineState
440     = ScanLineState
441         { insideBlocks :: !Integer
442         , startAddress :: !Integer
443         } deriving (Show)
444
445 instance NetTransformable AST.Address NetAST.Address where
446     transform _ (AST.LiteralAddress value) = do
447         return $ NetAST.Address value
448     transform context (AST.ParamAddress name) = do
449         let
450             value = getParamValue context name
451         return $ NetAST.Address value
452
453 instance NetTransformable a b => NetTransformable (AST.For a) [b] where
454     transform context ast = do
455         let
456             body = AST.body ast
457             varRanges = AST.varRanges ast
458         concreteRanges <- transform context varRanges
459         let
460             valueList = Map.foldWithKey iterations [] concreteRanges
461             iterContexts = map iterationContext valueList
462             ts = map (\c -> transform c body) iterContexts
463             fs = lefts ts
464             bs = rights ts
465         case fs of
466             [] -> return $ bs
467             _  -> Left $ CheckFailure (concat $ map failures fs)
468         where
469             iterations k vs [] = [Map.fromList [(k,v)] | v <- vs]
470             iterations k vs ms = concat $ map (f ms k) vs
471                 where
472                     f ms k v = map (Map.insert k v) ms
473             iterationContext varMap =
474                 let
475                     values = varValues context
476                 in context
477                     { varValues = values `Map.union` varMap }
478
479 instance NetTransformable AST.ForRange [Integer] where
480     transform context ast = do
481         let
482             start = AST.start ast
483             end = AST.end ast
484         startVal <- transform context start
485         endVal <- transform context end
486         return [startVal..endVal]
487
488 instance NetTransformable AST.ForLimit Integer where
489     transform _ (AST.LiteralLimit value) = return value
490     transform context (AST.ParamLimit name) = return $ getParamValue context name
491
492 instance NetTransformable a b => NetTransformable [a] [b] where
493     transform context ast = do
494         let
495             ts = map (transform context) ast
496             fs = lefts ts
497             bs = rights ts
498         case fs of
499             [] -> return bs
500             _  -> Left $ CheckFailure (concat $ map failures fs)
501
502 instance (Ord k, NetTransformable a b) => NetTransformable (Map k a) (Map k b) where
503     transform context ast = do
504         let
505             ks = Map.keys ast
506             es = Map.elems ast
507         ts <- transform context es
508         return $ Map.fromList (zip ks ts)
509
510 --
511 -- Checks
512 --
513 class NetCheckable a where
514     check :: Set NetAST.NodeId -> a -> Either CheckFailure ()
515
516 instance NetCheckable NetAST.NetSpec where
517     check context (NetAST.NetSpec net) = do
518         let
519             specContext = Map.keysSet net
520         check specContext $ Map.elems net
521
522 instance NetCheckable NetAST.NodeSpec where
523     check context net = do
524         let
525             translate = NetAST.translate net
526         check context translate
527
528 instance NetCheckable NetAST.MapSpec where
529     check context net = do
530         let
531            destNode = NetAST.destNode net
532         check context destNode
533
534 instance NetCheckable NetAST.NodeId where
535     check context net = do
536         if net `Set.member` context
537             then return ()
538             else Left $ CheckFailure [UndefinedReference $ show net]
539
540 instance NetCheckable a => NetCheckable [a] where
541     check context net = do
542         let
543             checked = map (check context) net
544             fs = lefts $ checked
545         case fs of
546             [] -> return ()
547             _  -> Left $ CheckFailure (concat $ map failures fs)
548
549 getModule :: Context -> String -> AST.Module
550 getModule context name =
551     let
552         modules = AST.modules $ spec context
553     in modules Map.! name
554
555 getParamValue :: Context -> String -> Integer
556 getParamValue context name =
557     let
558         params = paramValues context
559     in params Map.! name
560
561 getVarValue :: Context -> String -> Integer
562 getVarValue context name =
563     let
564         vars = varValues context
565     in vars Map.! name
566
567 checkDuplicates :: (Eq a, Show a) => [a] -> (String -> FailedCheck) -> Either CheckFailure ()
568 checkDuplicates nodeIds fail = do
569     let
570         duplicates = duplicateNames nodeIds
571     case duplicates of
572         [] -> return ()
573         _  -> Left $ CheckFailure (map (fail . show) duplicates)
574     where
575         duplicateNames [] = []
576         duplicateNames (x:xs)
577             | x `elem` xs = nub $ [x] ++ duplicateNames xs
578             | otherwise = duplicateNames xs