Sockeye: Implement port mappings
[barrelfish] / tools / sockeye / SockeyeNetBuilder.hs
1 {-
2     SockeyeNetBuilder.hs: Decoding net builder for Sockeye
3
4     Part of Sockeye
5
6     Copyright (c) 2017, ETH Zurich.
7
8     All rights reserved.
9
10     This file is distributed under the terms in the attached LICENSE file.
11     If you do not find this file, copies can be found by writing to:
12     ETH Zurich D-INFK, CAB F.78, Universitaetstr. 6, CH-8092 Zurich,
13     Attn: Systems Group.
14 -}
15
16 {-# LANGUAGE MultiParamTypeClasses #-}
17 {-# LANGUAGE FlexibleInstances #-}
18 {-# LANGUAGE FlexibleContexts #-}
19
20 module SockeyeNetBuilder
21 ( sockeyeBuildNet ) where
22
23 import Control.Monad.State
24
25 import Data.Either
26 import Data.List (nub, intercalate)
27 import Data.Map (Map)
28 import qualified Data.Map as Map
29 import Data.Maybe (catMaybes, fromMaybe, maybe)
30 import Data.Set (Set)
31 import qualified Data.Set as Set
32
33 import Numeric (showHex)
34
35 import qualified SockeyeAST as AST
36 import qualified SockeyeASTDecodingNet as NetAST
37
38 import Text.Groom (groom)
39 import Debug.Trace
40
41 type NetNodeDecl = (NetAST.NodeId, NetAST.NodeSpec)
42 type NetList = [NetNodeDecl]
43 type PortList = [NetAST.NodeId]
44 type PortMap = [(String, NetAST.NodeId)]
45
46 data FailedCheck
47     = ModuleInstLoop [String]
48     | DuplicateInPort !String !String
49     | DuplicateInMap !String !String
50     | UndefinedInPort !String !String
51     | DuplicateOutPort !String !String
52     | DuplicateOutMap !String !String
53     | UndefinedOutPort !String !String
54     | DuplicateIdentifer !String
55     | UndefinedReference !String
56
57 instance Show FailedCheck where
58     show (ModuleInstLoop loop) = concat ["Module instantiation loop :'", intercalate "' -> '" loop, "'"]
59     show (DuplicateInPort  modName port) = concat ["Multiple declarations of input port '", port, "' in '", modName, "'"]
60     show (DuplicateInMap   ns      port) = concat ["Multiple mappings for input port '", port, "' in '", ns, "'"]
61     show (UndefinedInPort  modName port) = concat ["'", port, "' is not an input port in '", modName, "'"]
62     show (DuplicateOutPort modName port) = concat ["Multiple declarations of output port '", port, "' in '", modName, "'"]
63     show (DuplicateOutMap   ns      port) = concat ["Multiple mappings for output port '", port, "' in '", ns, "'"]
64     show (UndefinedOutPort modName port) = concat ["'", port, "' is not an output port in '", modName, "'"]
65     show (DuplicateIdentifer ident)   = concat ["Multiple declarations of node '", show ident, "'"]
66     show (UndefinedReference ident)   = concat ["Reference to undefined node '", show ident, "'"]
67
68 newtype CheckFailure = CheckFailure
69     { failures :: [FailedCheck] }
70
71 instance Show CheckFailure where
72     show (CheckFailure fs) = unlines $ "":(map show fs)
73
74 data Context = Context
75     { spec         :: AST.SockeyeSpec
76     , modulePath   :: [String]
77     , curNamespace :: NetAST.Namespace
78     , paramValues  :: Map String Word
79     , varValues    :: Map String Word
80     , inPortMaps   :: Map String NetAST.NodeId
81     , outPortMaps  :: Map String NetAST.NodeId
82     }
83
84 sockeyeBuildNet :: AST.SockeyeSpec -> Either CheckFailure NetAST.NetSpec
85 sockeyeBuildNet ast = do
86     let
87         context = Context
88             { spec         = AST.SockeyeSpec Map.empty
89             , modulePath   = []
90             , curNamespace = NetAST.Namespace []
91             , paramValues  = Map.empty
92             , varValues    = Map.empty
93             , inPortMaps   = Map.empty
94             , outPortMaps  = Map.empty
95             }        
96     net <- transform context ast
97     trace (groom net) $ return ()
98     check Set.empty net
99     return net
100 --            
101 -- Build net
102 --
103 class NetTransformable a b where
104     transform :: Context -> a -> Either CheckFailure b
105
106 instance NetTransformable AST.SockeyeSpec NetAST.NetSpec where
107     transform context ast = do
108         let
109             rootInst = AST.ModuleInst
110                 { AST.namespace  = AST.SimpleIdent ""
111                 , AST.moduleName = "@root"
112                 , AST.arguments  = Map.empty
113                 , AST.inPortMap  = []
114                 , AST.outPortMap = []
115                 }
116             specContext = context
117                 { spec = ast }
118         netList <- transform specContext rootInst
119         let
120             nodeIds = map fst netList
121         checkDuplicates nodeIds DuplicateIdentifer
122         let
123             nodeMap = Map.fromList netList
124         return $ NetAST.NetSpec nodeMap
125
126 instance NetTransformable AST.Module NetList where
127     transform context ast = do
128         let
129             inPorts = AST.inputPorts ast
130             outPorts = AST.outputPorts ast
131             nodeDecls = AST.nodeDecls ast
132             moduleInsts = AST.moduleInsts ast
133         inDecls <- do
134             net <- transform context inPorts
135             return $ concat (net :: [NetList])
136         outDecls <- do
137             net <- transform context outPorts
138             return $ concat (net :: [NetList])
139         -- TODO check duplicate ports
140         -- TODO check mappings to non existing port
141         netDecls <- transform context nodeDecls
142         netInsts <- transform context moduleInsts
143         return $ concat (inDecls:outDecls:netDecls ++ netInsts)
144         where
145             nameWithArgs =
146                 let
147                     name = head $ modulePath context
148                     paramNames = AST.paramNames ast
149                     paramTypes = AST.paramTypeMap ast
150                     params = map (\p -> (p, paramTypes Map.! p)) paramNames
151                     argValues = map showValue params
152                 in concat [name, "(", intercalate ", " argValues, ")"]
153                 where
154                     showValue (name, AST.AddressParam) = "0x" ++ showHex (getParamValue context name) ""
155                     showValue (name, AST.NaturalParam) = show (getParamValue context name)
156             
157
158 instance NetTransformable AST.Port NetList where
159     transform context (AST.MultiPort for) = do
160         netPorts <- transform context for
161         return $ concat (netPorts :: [NetList])
162     transform context (AST.InputPort ident) = do
163         netIdent <- transform context ident
164         let
165             portMap = inPortMaps context
166             decl = mapPort portMap netIdent
167         return $ catMaybes [decl]
168             where
169                 mapPort portMap port = do
170                     let
171                         name = NetAST.name port
172                     mappedId <- Map.lookup name portMap
173                     return (mappedId, portMapTemplate { NetAST.overlay = Just port })
174     transform context (AST.OutputPort ident) = do
175         netIdent <- transform context ident
176         let
177             portMap = outPortMaps context
178             decl = mapPort portMap netIdent
179         return [decl]
180             where
181                 mapPort portMap port = let
182                     name = NetAST.name port
183                     mappedId = Map.lookup name portMap
184                     in (port, portMapTemplate { NetAST.overlay = mappedId })
185
186 portMapTemplate :: NetAST.NodeSpec
187 portMapTemplate = NetAST.NodeSpec
188     { NetAST.nodeType  = NetAST.Other
189     , NetAST.accept    = []
190     , NetAST.translate = []
191     , NetAST.overlay   = Nothing
192     }
193
194 instance NetTransformable AST.ModuleInst NetList where
195     transform context (AST.MultiModuleInst for) = do
196         net <- transform context for
197         return $ concat (net :: [NetList])
198     transform context ast = do
199         let
200             namespace = AST.namespace ast
201             name = AST.moduleName ast
202             args = AST.arguments ast
203             inPortMap = AST.inPortMap ast
204             outPortMap = AST.outPortMap ast
205             mod = getModule context name
206         checkSelfInst name
207         netNamespace <- transform context namespace
208         netArgs <- transform context args
209         netInMap <- transform context inPortMap
210         netOutMap <- transform context outPortMap
211         let
212             inMaps = concat (netInMap :: [PortMap])
213             outMaps = concat (netOutMap :: [PortMap])
214         checkDuplicates (map fst inMaps) (DuplicateInMap $ show netNamespace) 
215         checkDuplicates (map fst outMaps) (DuplicateOutMap $ show netNamespace)
216         let
217             modContext = moduleContext name netNamespace netArgs inMaps outMaps
218         transform modContext mod
219             where
220                 moduleContext name namespace args inMaps outMaps =
221                     let
222                         path = modulePath context
223                         base = NetAST.ns $ NetAST.namespace namespace
224                         newNs = case NetAST.name namespace of
225                             "" -> NetAST.Namespace base
226                             n  -> NetAST.Namespace $ base ++ [n]
227                     in context
228                         { modulePath   = name:path
229                         , curNamespace = newNs
230                         , paramValues  = args
231                         , varValues    = Map.empty
232                         , inPortMaps   = Map.fromList inMaps
233                         , outPortMaps  = Map.fromList outMaps
234                         }
235                 checkSelfInst name = do
236                     let
237                         path = modulePath context
238                     case loop path of
239                         [] -> return ()
240                         l  -> Left $ CheckFailure [ModuleInstLoop (reverse $ name:l)]
241                         where
242                             loop [] = []
243                             loop path@(p:ps)
244                                 | name `elem` path = p:(loop ps)
245                                 | otherwise = []
246
247
248 instance NetTransformable AST.PortMap PortMap where
249     transform context (AST.MultiPortMap for) = do
250         ts <- transform context for
251         return $ concat (ts :: [PortMap])
252     transform context ast = do
253         let
254             mappedId = AST.mappedId ast
255             mappedPort = AST.mappedPort ast
256         netMappedId <- transform context mappedId
257         netMappedPort <- transform context mappedPort
258         return [(NetAST.name netMappedPort, netMappedId)]
259
260 instance NetTransformable AST.ModuleArg Word where
261     transform context (AST.AddressArg value) = return value
262     transform context (AST.NaturalArg value) = return value
263     transform context (AST.ParamArg name) = return $ getParamValue context name
264
265 instance NetTransformable AST.Identifier NetAST.NodeId where
266     transform context ast = do
267         let
268             namespace = curNamespace context
269             name = identName ast
270         return NetAST.NodeId
271             { NetAST.namespace = namespace
272             , NetAST.name      = name
273             }
274             where
275                 identName (AST.SimpleIdent name) = name
276                 identName ident =
277                     let
278                         prefix = AST.prefix ident
279                         varName = AST.varName ident
280                         suffix = AST.suffix ident
281                         varValue = show $ getVarValue context varName
282                         suffixName = case suffix of
283                             Nothing -> ""
284                             Just s  -> identName s
285                     in prefix ++ varValue ++ suffixName
286
287 instance NetTransformable AST.NodeDecl NetList where
288     transform context (AST.MultiNodeDecl for) = do
289         ts <- transform context for
290         return $ concat (ts :: [NetList])
291     transform context ast = do
292         let
293             ident = AST.nodeId ast
294             nodeSpec = AST.nodeSpec ast
295         nodeId <- transform context ident
296         netNodeSpec <- transform context nodeSpec
297         return [(nodeId, netNodeSpec)]
298
299 instance NetTransformable AST.NodeSpec NetAST.NodeSpec where
300     transform context ast = do
301         let
302             nodeType = AST.nodeType ast
303             accept = AST.accept ast
304             translate = AST.translate ast
305             overlay = AST.overlay ast
306         netNodeType <- maybe (return NetAST.Other) (transform context) nodeType
307         netAccept <- transform context accept
308         netTranslate <- transform context translate
309         netOverlay <- case overlay of
310                 Nothing -> return Nothing
311                 Just o  -> do 
312                     t <- transform context o
313                     return $ Just t
314         return NetAST.NodeSpec
315             { NetAST.nodeType  = netNodeType
316             , NetAST.accept    = netAccept
317             , NetAST.translate = netTranslate
318             , NetAST.overlay   = netOverlay
319             }
320
321 instance NetTransformable AST.NodeType NetAST.NodeType where
322     transform _ AST.Memory = return NetAST.Memory
323     transform _ AST.Device = return NetAST.Device
324
325 instance NetTransformable AST.BlockSpec NetAST.BlockSpec where
326     transform context (AST.SingletonBlock address) = do
327         netAddress <- transform context address
328         return NetAST.BlockSpec
329             { NetAST.base  = netAddress
330             , NetAST.limit = netAddress
331             }
332     transform context (AST.RangeBlock base limit) = do
333         netBase <- transform context base
334         netLimit <- transform context limit
335         return NetAST.BlockSpec
336             { NetAST.base  = netBase
337             , NetAST.limit = netLimit
338             }
339     transform context (AST.LengthBlock base bits) = do
340         netBase <- transform context base
341         let
342             baseAddress = NetAST.address netBase
343             limit = baseAddress + 2^bits - 1
344             netLimit = NetAST.Address limit
345         return NetAST.BlockSpec
346             { NetAST.base  = netBase
347             , NetAST.limit = netLimit
348             }
349
350 instance NetTransformable AST.MapSpec NetAST.MapSpec where
351     transform context ast = do
352         let
353             block = AST.block ast
354             destNode = AST.destNode ast
355             destBase = fromMaybe (AST.base block) (AST.destBase ast)
356         netBlock <- transform context block
357         netDestNode <- transform context destNode
358         netDestBase <- transform context destBase
359         return NetAST.MapSpec
360             { NetAST.srcBlock = netBlock
361             , NetAST.destNode = netDestNode
362             , NetAST.destBase = netDestBase
363             }
364
365 instance NetTransformable AST.Address NetAST.Address where
366     transform _ (AST.LiteralAddress value) = do
367         return $ NetAST.Address value
368     transform context (AST.ParamAddress name) = do
369         let
370             value = getParamValue context name
371         return $ NetAST.Address value
372
373 instance NetTransformable a b => NetTransformable (AST.For a) [b] where
374     transform context ast = do
375         let
376             body = AST.body ast
377             varRanges = AST.varRanges ast
378         concreteRanges <- transform context varRanges
379         let
380             valueList = Map.foldWithKey iterations [] concreteRanges
381             iterContexts = map iterationContext valueList
382             ts = map (\c -> transform c body) iterContexts
383             fs = lefts ts
384             bs = rights ts
385         case fs of
386             [] -> return $ bs
387             _  -> Left $ CheckFailure (concat $ map failures fs)
388         where
389             iterations k vs [] = [Map.fromList [(k,v)] | v <- vs]
390             iterations k vs ms = concat $ map (f ms k) vs
391                 where
392                     f ms k v = map (Map.insert k v) ms
393             iterationContext varMap =
394                 let
395                     values = varValues context
396                 in context
397                     { varValues = values `Map.union` varMap }
398
399 instance NetTransformable AST.ForRange [Word] where
400     transform context ast = do
401         let
402             start = AST.start ast
403             end = AST.end ast
404         startVal <- transform context start
405         endVal <- transform context end
406         return [startVal..endVal]
407
408 instance NetTransformable AST.ForLimit Word where
409     transform _ (AST.LiteralLimit value) = return value
410     transform context (AST.ParamLimit name) = return $ getParamValue context name
411
412 instance NetTransformable a b => NetTransformable [a] [b] where
413     transform context ast = do
414         let
415             ts = map (transform context) ast
416             fs = lefts ts
417             bs = rights ts
418         case fs of
419             [] -> return bs
420             _  -> Left $ CheckFailure (concat $ map failures fs)
421
422 instance (Ord k, NetTransformable a b) => NetTransformable (Map k a) (Map k b) where
423     transform context ast = do
424         let
425             ks = Map.keys ast
426             es = Map.elems ast
427         ts <- transform context es
428         return $ Map.fromList (zip ks ts)
429
430 --
431 -- Checks
432 --
433 class NetCheckable a where
434     check :: Set NetAST.NodeId -> a -> Either CheckFailure ()
435
436 instance NetCheckable NetAST.NetSpec where
437     check context (NetAST.NetSpec net) = do
438         let
439             specContext = Map.keysSet net
440         check specContext $ Map.elems net
441
442 instance NetCheckable NetAST.NodeSpec where
443     check context net = do
444         let
445             translate = NetAST.translate net
446             overlay = NetAST.overlay net
447         check context translate
448         maybe (return ()) (check context) overlay
449
450 instance NetCheckable NetAST.MapSpec where
451     check context net = do
452         let
453            destNode = NetAST.destNode net
454         check context destNode
455
456 instance NetCheckable NetAST.NodeId where
457     check context net = do
458         if net `Set.member` context
459             then return ()
460             else Left $ CheckFailure [UndefinedReference $ show net]
461
462 instance NetCheckable a => NetCheckable [a] where
463     check context net = do
464         let
465             checked = map (check context) net
466             fs = lefts $ checked
467         case fs of
468             [] -> return ()
469             _  -> Left $ CheckFailure (concat $ map failures fs)
470
471 getModule :: Context -> String -> AST.Module
472 getModule context name =
473     let
474         modules = AST.modules $ spec context
475     in modules Map.! name
476
477 getParamValue :: Context -> String -> Word
478 getParamValue context name =
479     let
480         params = paramValues context
481     in params Map.! name
482
483 getVarValue :: Context -> String -> Word
484 getVarValue context name =
485     let
486         vars = varValues context
487     in vars Map.! name
488
489 checkDuplicates :: (Eq a, Show a) => [a] -> (String -> FailedCheck) -> Either CheckFailure ()
490 checkDuplicates nodeIds fail = do
491     let
492         duplicates = duplicateNames nodeIds
493     case duplicates of
494         [] -> return ()
495         _  -> Left $ CheckFailure (map (fail . show) duplicates)
496     where
497         duplicateNames [] = []
498         duplicateNames (x:xs)
499             | x `elem` xs = nub $ [x] ++ duplicateNames xs
500             | otherwise = duplicateNames xs