Sockeye: Implement reference check in input port declarations
[barrelfish] / tools / sockeye / SockeyeNetBuilder.hs
1 {-
2     SockeyeNetBuilder.hs: Decoding net builder for Sockeye
3
4     Part of Sockeye
5
6     Copyright (c) 2017, ETH Zurich.
7
8     All rights reserved.
9
10     This file is distributed under the terms in the attached LICENSE file.
11     If you do not find this file, copies can be found by writing to:
12     ETH Zurich D-INFK, CAB F.78, Universitaetstr. 6, CH-8092 Zurich,
13     Attn: Systems Group.
14 -}
15
16 {-# LANGUAGE MultiParamTypeClasses #-}
17 {-# LANGUAGE FlexibleInstances #-}
18 {-# LANGUAGE FlexibleContexts #-}
19
20 module SockeyeNetBuilder
21 ( buildSockeyeNet ) where
22
23 import Control.Monad.State
24
25 import Data.Either
26 import Data.List (nub, intercalate, sort)
27 import Data.Map (Map)
28 import qualified Data.Map as Map
29 import Data.Maybe (fromMaybe)
30 import Data.Set (Set)
31 import qualified Data.Set as Set
32
33 import SockeyeChecks
34
35 import qualified SockeyeASTInstantiator as InstAST
36 import qualified SockeyeASTDecodingNet as NetAST
37
38 import Debug.Trace
39
40 data NetBuildFail
41     = UndefinedOutPort   !String !String
42     | UndefinedInPort    !String !String
43     | UndefinedReference !String !String
44
45 instance Show NetBuildFail where
46     show (UndefinedInPort  inst port)  = concat ["Mapping to undefined input port '",   port, "' in module instantiation '", inst, "'"]
47     show (UndefinedOutPort inst port)  = concat ["Mapping to undefined output port '",  port, "' in module instantiation '", inst, "'"]
48     show (UndefinedReference context ident) = concat ["Reference to undefined node '", ident, "' in ", context]
49
50 type PortMap = Map InstAST.Identifier NetAST.NodeId
51
52 data Context = Context
53     { modules      :: Map InstAST.Identifier InstAST.Module
54     , curModule    :: !String
55     , curNamespace :: [String]
56     , curNode      :: !String
57     , inPortMap    :: PortMap
58     , outPortMap   :: PortMap
59     , nodes        :: Set String
60     , mappedBlocks :: [InstAST.BlockSpec]
61     }
62
63 buildSockeyeNet :: InstAST.SockeyeSpec -> Either (FailedChecks NetBuildFail) NetAST.NetSpec
64 buildSockeyeNet ast = do
65     let
66         context = Context
67             { modules      = Map.empty
68             , curModule    = ""
69             , curNamespace = []
70             , curNode      = ""
71             , inPortMap    = Map.empty
72             , outPortMap   = Map.empty
73             , nodes        = Set.empty
74             , mappedBlocks = []
75             }        
76     net <- runChecks $ transform context ast
77     return net
78
79 --            
80 -- Build net
81 --
82 class NetTransformable a b where
83     transform :: Context -> a -> Checks NetBuildFail b
84
85 instance NetTransformable InstAST.SockeyeSpec NetAST.NetSpec where
86     transform context ast = do
87         let
88             rootInst = InstAST.root ast
89             mods = InstAST.modules ast
90             specContext = context
91                 { modules = mods }
92         transform specContext rootInst
93
94 instance NetTransformable InstAST.Module NetAST.NetSpec where
95     transform context ast = do
96         let inPorts = InstAST.inputPorts ast
97             outPorts = InstAST.outputPorts ast
98             moduleInsts = InstAST.moduleInsts ast
99             nodeDecls = InstAST.nodeDecls ast
100             outPortIds = map InstAST.portId outPorts
101             inMapIds = concatMap Map.elems $ map InstAST.inPortMap moduleInsts
102             declIds = map InstAST.nodeId nodeDecls
103             modContext = context
104                 { nodes = Set.fromList $ outPortIds ++ inMapIds ++ declIds }
105         -- TODO: check mapping to undefined port
106         inPortDecls <- transform modContext inPorts
107         outPortDecls <- transform modContext outPorts
108         netDecls <- transform modContext nodeDecls
109         netInsts <- transform modContext moduleInsts     
110         return $ Map.unions (inPortDecls ++ outPortDecls ++ netDecls ++ netInsts)
111
112 instance NetTransformable InstAST.Port NetAST.NetSpec where
113     transform context ast@(InstAST.InputPort {}) = do
114         let portId = InstAST.portId ast
115             portWidth = InstAST.portWidth ast
116             portMap = inPortMap context
117             mappedId = Map.lookup portId portMap
118             errorContext = "input port declaration"
119         checkReference context (UndefinedReference errorContext) portId
120         netPortId <- transform context portId
121         case mappedId of
122             Nothing    -> return Map.empty
123             Just ident -> do
124                 let node = portNode netPortId portWidth
125                 return $ Map.fromList [(ident, node)]
126     transform context ast@(InstAST.OutputPort {}) = do
127         let portId = InstAST.portId ast
128             portWidth = InstAST.portWidth ast
129             portMap = outPortMap context
130             mappedId = Map.lookup portId portMap
131         netPortId <- transform context portId
132         case mappedId of
133             Nothing    -> return $ Map.fromList [(netPortId, portNodeTemplate)]
134             Just ident -> do
135                 let node = portNode ident portWidth
136                 return $ Map.fromList $ [(netPortId, node)]
137
138 portNode :: NetAST.NodeId -> Integer -> NetAST.NodeSpec
139 portNode destNode width =
140     let base = 0
141         limit = 2^width - 1
142         srcBlock = NetAST.BlockSpec
143             { NetAST.base  = base
144             , NetAST.limit = limit
145             }
146         map = NetAST.MapSpec
147                 { NetAST.srcBlock = srcBlock
148                 , NetAST.destNode = destNode
149                 , NetAST.destBase = base
150                 }
151     in portNodeTemplate { NetAST.translate = [map] }
152
153 portNodeTemplate :: NetAST.NodeSpec
154 portNodeTemplate = NetAST.NodeSpec
155     { NetAST.nodeType  = NetAST.Other
156     , NetAST.accept    = []
157     , NetAST.translate = []
158     }
159
160 instance NetTransformable InstAST.ModuleInst NetAST.NetSpec where
161     transform context ast = do
162         let name = InstAST.moduleName ast
163             namespace = InstAST.namespace ast
164             inPortMap = InstAST.inPortMap ast
165             outPortMap = InstAST.outPortMap ast
166             mod = (modules context) Map.! name
167             errorContext = concat ["port mapping for '", name, " as ", namespace, "'"]
168         mapM_ (checkReference context $ UndefinedReference errorContext) $ (Map.elems inPortMap) ++ (Map.elems outPortMap)
169         netInMap <- transform context inPortMap
170         netOutMap <- transform context outPortMap
171         let instContext = context
172                 { curModule    = name
173                 , curNamespace = namespace:(curNamespace context)
174                 , inPortMap    = netInMap
175                 , outPortMap   = netOutMap
176                 }
177         transform instContext mod
178
179 instance NetTransformable InstAST.NodeDecl NetAST.NetSpec where
180     transform context ast = do
181         let nodeId = InstAST.nodeId ast
182             nodeSpec = InstAST.nodeSpec ast
183             nodeContext = context
184                 { curNode = nodeId }
185         netNodeId <- transform context nodeId
186         netNodeSpec <- transform nodeContext nodeSpec
187         return $ Map.fromList [(netNodeId, netNodeSpec)]
188
189 instance NetTransformable InstAST.Identifier NetAST.NodeId where
190     transform context ast = do
191         let namespace = curNamespace context
192         return NetAST.NodeId
193             { NetAST.namespace = namespace
194             , NetAST.name      = ast
195             }
196
197 instance NetTransformable InstAST.NodeSpec NetAST.NodeSpec where
198     transform context ast = do
199         let
200             nodeType = InstAST.nodeType ast
201             accept = InstAST.accept ast
202             translate = InstAST.translate ast
203             reserved = InstAST.reserved ast
204             overlay = InstAST.overlay ast
205         netTranslate <- transform context translate
206         let
207             mapBlocks = map NetAST.srcBlock netTranslate
208             nodeContext = context
209                 { mappedBlocks = accept ++ mapBlocks ++ reserved }
210         netOverlay <- case overlay of
211                 Nothing -> return []
212                 Just o  -> transform nodeContext o
213         return NetAST.NodeSpec
214             { NetAST.nodeType  = nodeType
215             , NetAST.accept    = accept
216             , NetAST.translate = netTranslate ++ netOverlay
217             }
218
219 instance NetTransformable InstAST.MapSpec NetAST.MapSpec where
220     transform context ast = do
221         let
222             srcBlock = InstAST.srcBlock ast
223             destNode = InstAST.destNode ast
224             destBase = InstAST.destBase ast
225             errorContext = "tranlate set of node '" ++ curNode context ++ "'"
226         checkReference context (UndefinedReference errorContext) destNode
227         netDestNode <- transform context destNode
228         return NetAST.MapSpec
229             { NetAST.srcBlock = srcBlock
230             , NetAST.destNode = netDestNode
231             , NetAST.destBase = destBase
232             }
233
234 instance NetTransformable InstAST.OverlaySpec [NetAST.MapSpec] where
235     transform context ast = do
236         let
237             over = InstAST.over ast
238             width = InstAST.width ast
239             blocks = mappedBlocks context
240             errorContext = "overlay of node '" ++ curNode context ++ "'"
241         checkReference context (UndefinedReference errorContext) over
242         netOver <- transform context over
243         let maps = overlayMaps netOver width blocks
244         return maps
245
246 overlayMaps :: NetAST.NodeId -> Integer -> [NetAST.BlockSpec] -> [NetAST.MapSpec]
247 overlayMaps destId width blocks =
248     let
249         blockPoints = concat $ map toScanPoints blocks
250         maxAddress = 2^width
251         overStop  = BlockStart $ maxAddress
252         scanPoints = filter ((maxAddress >=) . address) $ sort (overStop:blockPoints)
253         startState = ScanLineState
254             { insideBlocks    = 0
255             , startAddress    = 0
256             }
257     in evalState (scanLine scanPoints []) startState
258     where
259         toScanPoints (NetAST.BlockSpec base limit) =
260                 [ BlockStart base
261                 , BlockEnd   limit
262                 ]
263         scanLine [] ms = return ms
264         scanLine (p:ps) ms = do
265             maps <- pointAction p ms
266             scanLine ps maps
267         pointAction (BlockStart a) ms = do
268             s <- get       
269             let
270                 i = insideBlocks s
271                 base = startAddress s
272                 limit = a - 1
273             maps <- if (i == 0) && (base <= limit)
274                 then
275                     let
276                         baseAddress = startAddress s
277                         limitAddress = a - 1
278                         srcBlock = NetAST.BlockSpec baseAddress limitAddress
279                         m = NetAST.MapSpec srcBlock destId baseAddress
280                     in return $ m:ms
281                 else return ms
282             modify (\s -> s { insideBlocks = i + 1})
283             return maps
284         pointAction (BlockEnd a) ms = do
285             s <- get
286             let
287                 i = insideBlocks s
288             put $ ScanLineState (i - 1) (a + 1)
289             return ms
290
291 data StoppingPoint
292     = BlockStart { address :: !NetAST.Address }
293     | BlockEnd   { address :: !NetAST.Address }
294     deriving (Eq, Show)
295
296 instance Ord StoppingPoint where
297     (<=) (BlockStart a1) (BlockEnd   a2)
298         | a1 == a2 = True
299         | otherwise = a1 <= a2
300     (<=) (BlockEnd   a1) (BlockStart a2)
301         | a1 == a2 = False
302         | otherwise = a1 <= a2
303     (<=) sp1 sp2 = (address sp1) <= (address sp2)
304
305 data ScanLineState
306     = ScanLineState
307         { insideBlocks :: !Integer
308         , startAddress :: !NetAST.Address
309         } deriving (Show)
310
311 instance (Traversable t, NetTransformable a b) => NetTransformable (t a)  (t b) where
312     transform context as = mapM (transform context) as
313
314 checkReference :: Context -> (String -> NetBuildFail) -> String -> (Checks NetBuildFail) ()
315 checkReference context fail name =
316     if name `Set.member` (nodes context)
317         then return ()
318         else failCheck (curModule context) (fail name)