8373809c6803cebece1a5a2b1ba9437fe488f061
[barrelfish] / tools / sockeye / SockeyeNetBuilder.hs
1 {-
2     SockeyeNetBuilder.hs: Decoding net builder for Sockeye
3
4     Part of Sockeye
5
6     Copyright (c) 2017, ETH Zurich.
7
8     All rights reserved.
9
10     This file is distributed under the terms in the attached LICENSE file.
11     If you do not find this file, copies can be found by writing to:
12     ETH Zurich D-INFK, CAB F.78, Universitaetstr. 6, CH-8092 Zurich,
13     Attn: Systems Group.
14 -}
15
16 {-# LANGUAGE MultiParamTypeClasses #-}
17 {-# LANGUAGE FlexibleInstances #-}
18 {-# LANGUAGE FlexibleContexts #-}
19
20 module SockeyeNetBuilder
21 ( sockeyeBuildNet ) where
22
23 import Control.Monad.State
24
25 import Data.Either
26 import Data.List (nub, intercalate)
27 import Data.Map (Map)
28 import qualified Data.Map as Map
29 import Data.Maybe (catMaybes, fromMaybe, maybe)
30 import Data.Set (Set)
31 import qualified Data.Set as Set
32
33 import Numeric (showHex)
34
35 import qualified SockeyeAST as AST
36 import qualified SockeyeASTDecodingNet as NetAST
37
38 type NetNodeDecl = (NetAST.NodeId, NetAST.NodeSpec)
39 type NetList = [NetNodeDecl]
40 type PortList = [NetAST.NodeId]
41 type PortMap = [(String, NetAST.NodeId)]
42
43 data FailedCheck
44     = ModuleInstLoop [String]
45     | DuplicateInPort !String !String
46     | DuplicateInMap !String !String
47     | UndefinedInPort !String !String
48     | DuplicateOutPort !String !String
49     | DuplicateOutMap !String !String
50     | UndefinedOutPort !String !String
51     | DuplicateIdentifer !String
52     | UndefinedReference !String
53
54 instance Show FailedCheck where
55     show (ModuleInstLoop loop) = concat ["Module instantiation loop :'", intercalate "' -> '" loop, "'"]
56     show (DuplicateInPort  modName port) = concat ["Multiple declarations of input port '", port, "' in '", modName, "'"]
57     show (DuplicateInMap   ns      port) = concat ["Multiple mappings for input port '", port, "' in '", ns, "'"]
58     show (UndefinedInPort  modName port) = concat ["'", port, "' is not an input port in '", modName, "'"]
59     show (DuplicateOutPort modName port) = concat ["Multiple declarations of output port '", port, "' in '", modName, "'"]
60     show (DuplicateOutMap   ns      port) = concat ["Multiple mappings for output port '", port, "' in '", ns, "'"]
61     show (UndefinedOutPort modName port) = concat ["'", port, "' is not an output port in '", modName, "'"]
62     show (DuplicateIdentifer ident)   = concat ["Multiple declarations of node '", show ident, "'"]
63     show (UndefinedReference ident)   = concat ["Reference to undefined node '", show ident, "'"]
64
65 newtype CheckFailure = CheckFailure
66     { failures :: [FailedCheck] }
67
68 instance Show CheckFailure where
69     show (CheckFailure fs) = unlines $ "":(map show fs)
70
71 data Context = Context
72     { spec         :: AST.SockeyeSpec
73     , modulePath   :: [String]
74     , curNamespace :: NetAST.Namespace
75     , paramValues  :: Map String Word
76     , varValues    :: Map String Word
77     , inPortMaps   :: Map String NetAST.NodeId
78     , outPortMaps  :: Map String NetAST.NodeId
79     }
80
81 sockeyeBuildNet :: AST.SockeyeSpec -> Either CheckFailure NetAST.NetSpec
82 sockeyeBuildNet ast = do
83     let
84         context = Context
85             { spec         = AST.SockeyeSpec Map.empty
86             , modulePath   = []
87             , curNamespace = NetAST.Namespace []
88             , paramValues  = Map.empty
89             , varValues    = Map.empty
90             , inPortMaps   = Map.empty
91             , outPortMaps  = Map.empty
92             }        
93     net <- transform context ast
94     check Set.empty net
95     return net
96 --            
97 -- Build net
98 --
99 class NetTransformable a b where
100     transform :: Context -> a -> Either CheckFailure b
101
102 instance NetTransformable AST.SockeyeSpec NetAST.NetSpec where
103     transform context ast = do
104         let
105             rootInst = AST.ModuleInst
106                 { AST.namespace  = AST.SimpleIdent ""
107                 , AST.moduleName = "@root"
108                 , AST.arguments  = Map.empty
109                 , AST.inPortMap  = []
110                 , AST.outPortMap = []
111                 }
112             specContext = context
113                 { spec = ast }
114         netList <- transform specContext rootInst
115         let
116             nodeIds = map fst netList
117         checkDuplicates nodeIds DuplicateIdentifer
118         let
119             nodeMap = Map.fromList netList
120         return $ NetAST.NetSpec nodeMap
121
122 instance NetTransformable AST.Module NetList where
123     transform context ast = do
124         let
125             inPorts = AST.inputPorts ast
126             outPorts = AST.outputPorts ast
127             nodeDecls = AST.nodeDecls ast
128             moduleInsts = AST.moduleInsts ast
129         inDecls <- do
130             net <- transform context inPorts
131             return $ concat (net :: [NetList])
132         outDecls <- do
133             net <- transform context outPorts
134             return $ concat (net :: [NetList])
135         -- TODO check duplicate ports
136         -- TODO check mappings to non existing port
137         netDecls <- transform context nodeDecls
138         netInsts <- transform context moduleInsts
139         return $ concat (inDecls:outDecls:netDecls ++ netInsts)
140         where
141             nameWithArgs =
142                 let
143                     name = head $ modulePath context
144                     paramNames = AST.paramNames ast
145                     paramTypes = AST.paramTypeMap ast
146                     params = map (\p -> (p, paramTypes Map.! p)) paramNames
147                     argValues = map showValue params
148                 in concat [name, "(", intercalate ", " argValues, ")"]
149                 where
150                     showValue (name, AST.AddressParam) = "0x" ++ showHex (getParamValue context name) ""
151                     showValue (name, AST.NaturalParam) = show (getParamValue context name)
152             
153
154 instance NetTransformable AST.Port NetList where
155     transform context (AST.MultiPort for) = do
156         netPorts <- transform context for
157         return $ concat (netPorts :: [NetList])
158     transform context (AST.InputPort ident) = do
159         netIdent <- transform context ident
160         let
161             portMap = inPortMaps context
162             decl = mapPort portMap netIdent
163         return $ catMaybes [decl]
164             where
165                 mapPort portMap port = do
166                     let
167                         name = NetAST.name port
168                     mappedId <- Map.lookup name portMap
169                     return (mappedId, portMapTemplate { NetAST.overlay = Just port })
170     transform context (AST.OutputPort ident) = do
171         netIdent <- transform context ident
172         let
173             portMap = outPortMaps context
174             decl = mapPort portMap netIdent
175         return [decl]
176             where
177                 mapPort portMap port = let
178                     name = NetAST.name port
179                     mappedId = Map.lookup name portMap
180                     in (port, portMapTemplate { NetAST.overlay = mappedId })
181
182 portMapTemplate :: NetAST.NodeSpec
183 portMapTemplate = NetAST.NodeSpec
184     { NetAST.nodeType  = NetAST.Other
185     , NetAST.accept    = []
186     , NetAST.translate = []
187     , NetAST.overlay   = Nothing
188     }
189
190 instance NetTransformable AST.ModuleInst NetList where
191     transform context (AST.MultiModuleInst for) = do
192         net <- transform context for
193         return $ concat (net :: [NetList])
194     transform context ast = do
195         let
196             namespace = AST.namespace ast
197             name = AST.moduleName ast
198             args = AST.arguments ast
199             inPortMap = AST.inPortMap ast
200             outPortMap = AST.outPortMap ast
201             mod = getModule context name
202         checkSelfInst name
203         netNamespace <- transform context namespace
204         netArgs <- transform context args
205         netInMap <- transform context inPortMap
206         netOutMap <- transform context outPortMap
207         let
208             inMaps = concat (netInMap :: [PortMap])
209             outMaps = concat (netOutMap :: [PortMap])
210         checkDuplicates (map fst inMaps) (DuplicateInMap $ show netNamespace) 
211         checkDuplicates (map fst outMaps) (DuplicateOutMap $ show netNamespace)
212         let
213             modContext = moduleContext name netNamespace netArgs inMaps outMaps
214         transform modContext mod
215             where
216                 moduleContext name namespace args inMaps outMaps =
217                     let
218                         path = modulePath context
219                         base = NetAST.ns $ NetAST.namespace namespace
220                         newNs = case NetAST.name namespace of
221                             "" -> NetAST.Namespace base
222                             n  -> NetAST.Namespace $ base ++ [n]
223                     in context
224                         { modulePath   = name:path
225                         , curNamespace = newNs
226                         , paramValues  = args
227                         , varValues    = Map.empty
228                         , inPortMaps   = Map.fromList inMaps
229                         , outPortMaps  = Map.fromList outMaps
230                         }
231                 checkSelfInst name = do
232                     let
233                         path = modulePath context
234                     case loop path of
235                         [] -> return ()
236                         l  -> Left $ CheckFailure [ModuleInstLoop (reverse $ name:l)]
237                         where
238                             loop [] = []
239                             loop path@(p:ps)
240                                 | name `elem` path = p:(loop ps)
241                                 | otherwise = []
242
243
244 instance NetTransformable AST.PortMap PortMap where
245     transform context (AST.MultiPortMap for) = do
246         ts <- transform context for
247         return $ concat (ts :: [PortMap])
248     transform context ast = do
249         let
250             mappedId = AST.mappedId ast
251             mappedPort = AST.mappedPort ast
252         netMappedId <- transform context mappedId
253         netMappedPort <- transform context mappedPort
254         return [(NetAST.name netMappedPort, netMappedId)]
255
256 instance NetTransformable AST.ModuleArg Word where
257     transform context (AST.AddressArg value) = return value
258     transform context (AST.NaturalArg value) = return value
259     transform context (AST.ParamArg name) = return $ getParamValue context name
260
261 instance NetTransformable AST.Identifier NetAST.NodeId where
262     transform context ast = do
263         let
264             namespace = curNamespace context
265             name = identName ast
266         return NetAST.NodeId
267             { NetAST.namespace = namespace
268             , NetAST.name      = name
269             }
270             where
271                 identName (AST.SimpleIdent name) = name
272                 identName ident =
273                     let
274                         prefix = AST.prefix ident
275                         varName = AST.varName ident
276                         suffix = AST.suffix ident
277                         varValue = show $ getVarValue context varName
278                         suffixName = case suffix of
279                             Nothing -> ""
280                             Just s  -> identName s
281                     in prefix ++ varValue ++ suffixName
282
283 instance NetTransformable AST.NodeDecl NetList where
284     transform context (AST.MultiNodeDecl for) = do
285         ts <- transform context for
286         return $ concat (ts :: [NetList])
287     transform context ast = do
288         let
289             ident = AST.nodeId ast
290             nodeSpec = AST.nodeSpec ast
291         nodeId <- transform context ident
292         netNodeSpec <- transform context nodeSpec
293         return [(nodeId, netNodeSpec)]
294
295 instance NetTransformable AST.NodeSpec NetAST.NodeSpec where
296     transform context ast = do
297         let
298             nodeType = AST.nodeType ast
299             accept = AST.accept ast
300             translate = AST.translate ast
301             overlay = AST.overlay ast
302         netNodeType <- maybe (return NetAST.Other) (transform context) nodeType
303         netAccept <- transform context accept
304         netTranslate <- transform context translate
305         netOverlay <- case overlay of
306                 Nothing -> return Nothing
307                 Just o  -> do 
308                     t <- transform context o
309                     return $ Just t
310         return NetAST.NodeSpec
311             { NetAST.nodeType  = netNodeType
312             , NetAST.accept    = netAccept
313             , NetAST.translate = netTranslate
314             , NetAST.overlay   = netOverlay
315             }
316
317 instance NetTransformable AST.NodeType NetAST.NodeType where
318     transform _ AST.Memory = return NetAST.Memory
319     transform _ AST.Device = return NetAST.Device
320
321 instance NetTransformable AST.BlockSpec NetAST.BlockSpec where
322     transform context (AST.SingletonBlock address) = do
323         netAddress <- transform context address
324         return NetAST.BlockSpec
325             { NetAST.base  = netAddress
326             , NetAST.limit = netAddress
327             }
328     transform context (AST.RangeBlock base limit) = do
329         netBase <- transform context base
330         netLimit <- transform context limit
331         return NetAST.BlockSpec
332             { NetAST.base  = netBase
333             , NetAST.limit = netLimit
334             }
335     transform context (AST.LengthBlock base bits) = do
336         netBase <- transform context base
337         let
338             baseAddress = NetAST.address netBase
339             limit = baseAddress + 2^bits - 1
340             netLimit = NetAST.Address limit
341         return NetAST.BlockSpec
342             { NetAST.base  = netBase
343             , NetAST.limit = netLimit
344             }
345
346 instance NetTransformable AST.MapSpec NetAST.MapSpec where
347     transform context ast = do
348         let
349             block = AST.block ast
350             destNode = AST.destNode ast
351             destBase = fromMaybe (AST.base block) (AST.destBase ast)
352         netBlock <- transform context block
353         netDestNode <- transform context destNode
354         netDestBase <- transform context destBase
355         return NetAST.MapSpec
356             { NetAST.srcBlock = netBlock
357             , NetAST.destNode = netDestNode
358             , NetAST.destBase = netDestBase
359             }
360
361 instance NetTransformable AST.Address NetAST.Address where
362     transform _ (AST.LiteralAddress value) = do
363         return $ NetAST.Address value
364     transform context (AST.ParamAddress name) = do
365         let
366             value = getParamValue context name
367         return $ NetAST.Address value
368
369 instance NetTransformable a b => NetTransformable (AST.For a) [b] where
370     transform context ast = do
371         let
372             body = AST.body ast
373             varRanges = AST.varRanges ast
374         concreteRanges <- transform context varRanges
375         let
376             valueList = Map.foldWithKey iterations [] concreteRanges
377             iterContexts = map iterationContext valueList
378             ts = map (\c -> transform c body) iterContexts
379             fs = lefts ts
380             bs = rights ts
381         case fs of
382             [] -> return $ bs
383             _  -> Left $ CheckFailure (concat $ map failures fs)
384         where
385             iterations k vs [] = [Map.fromList [(k,v)] | v <- vs]
386             iterations k vs ms = concat $ map (f ms k) vs
387                 where
388                     f ms k v = map (Map.insert k v) ms
389             iterationContext varMap =
390                 let
391                     values = varValues context
392                 in context
393                     { varValues = values `Map.union` varMap }
394
395 instance NetTransformable AST.ForRange [Word] where
396     transform context ast = do
397         let
398             start = AST.start ast
399             end = AST.end ast
400         startVal <- transform context start
401         endVal <- transform context end
402         return [startVal..endVal]
403
404 instance NetTransformable AST.ForLimit Word where
405     transform _ (AST.LiteralLimit value) = return value
406     transform context (AST.ParamLimit name) = return $ getParamValue context name
407
408 instance NetTransformable a b => NetTransformable [a] [b] where
409     transform context ast = do
410         let
411             ts = map (transform context) ast
412             fs = lefts ts
413             bs = rights ts
414         case fs of
415             [] -> return bs
416             _  -> Left $ CheckFailure (concat $ map failures fs)
417
418 instance (Ord k, NetTransformable a b) => NetTransformable (Map k a) (Map k b) where
419     transform context ast = do
420         let
421             ks = Map.keys ast
422             es = Map.elems ast
423         ts <- transform context es
424         return $ Map.fromList (zip ks ts)
425
426 --
427 -- Checks
428 --
429 class NetCheckable a where
430     check :: Set NetAST.NodeId -> a -> Either CheckFailure ()
431
432 instance NetCheckable NetAST.NetSpec where
433     check context (NetAST.NetSpec net) = do
434         let
435             specContext = Map.keysSet net
436         check specContext $ Map.elems net
437
438 instance NetCheckable NetAST.NodeSpec where
439     check context net = do
440         let
441             translate = NetAST.translate net
442             overlay = NetAST.overlay net
443         check context translate
444         maybe (return ()) (check context) overlay
445
446 instance NetCheckable NetAST.MapSpec where
447     check context net = do
448         let
449            destNode = NetAST.destNode net
450         check context destNode
451
452 instance NetCheckable NetAST.NodeId where
453     check context net = do
454         if net `Set.member` context
455             then return ()
456             else Left $ CheckFailure [UndefinedReference $ show net]
457
458 instance NetCheckable a => NetCheckable [a] where
459     check context net = do
460         let
461             checked = map (check context) net
462             fs = lefts $ checked
463         case fs of
464             [] -> return ()
465             _  -> Left $ CheckFailure (concat $ map failures fs)
466
467 getModule :: Context -> String -> AST.Module
468 getModule context name =
469     let
470         modules = AST.modules $ spec context
471     in modules Map.! name
472
473 getParamValue :: Context -> String -> Word
474 getParamValue context name =
475     let
476         params = paramValues context
477     in params Map.! name
478
479 getVarValue :: Context -> String -> Word
480 getVarValue context name =
481     let
482         vars = varValues context
483     in vars Map.! name
484
485 checkDuplicates :: (Eq a, Show a) => [a] -> (String -> FailedCheck) -> Either CheckFailure ()
486 checkDuplicates nodeIds fail = do
487     let
488         duplicates = duplicateNames nodeIds
489     case duplicates of
490         [] -> return ()
491         _  -> Left $ CheckFailure (map (fail . show) duplicates)
492     where
493         duplicateNames [] = []
494         duplicateNames (x:xs)
495             | x `elem` xs = nub $ [x] ++ duplicateNames xs
496             | otherwise = duplicateNames xs