Sockeye: Finish checker for NodeSpec
[barrelfish] / tools / sockeye / SockeyeChecker.hs
1 {-
2     SockeyeChecker.hs: AST checker for Sockeye
3
4     Part of Sockeye
5
6     Copyright (c) 2017, ETH Zurich.
7
8     All rights reserved.
9
10     This file is distributed under the terms in the attached LICENSE file.
11     If you do not find this file, copies can be found by writing to:
12     ETH Zurich D-INFK, CAB F.78, Universitaetstr. 6, CH-8092 Zurich,
13     Attn: Systems Group.
14 -}
15
16 {-# LANGUAGE MultiParamTypeClasses #-}
17 {-# LANGUAGE FlexibleContexts #-}
18
19 module SockeyeChecker
20 ( checkSockeye ) where
21
22 import Control.Monad (join)
23
24 import Data.List (nub)
25 import Data.Map (Map)
26 import qualified Data.Map as Map
27 import Data.Set (Set)
28 import qualified Data.Set as Set
29 import Data.Either
30
31 import qualified SockeyeASTFrontend as ASTF
32 import qualified SockeyeASTIntermediate as ASTI
33
34 import Debug.Trace
35
36 data FailedCheck
37     = DuplicateModule String
38     | DuplicateParameter String
39     | DuplicateVariable String
40     | NoSuchModule String
41     | NoSuchParameter String
42     | NoSuchVariable String
43     | ParameterTypeMismatch String ASTI.ModuleParamType ASTI.ModuleParamType
44
45 instance Show FailedCheck where
46     show (DuplicateModule name)    = concat ["Multiple definitions for module '", name, "'."]
47     show (DuplicateParameter name) = concat ["Multiple definitions for parameter '", name, "'."]
48     show (DuplicateVariable name)  = concat ["Multiple definitions for variable '", name, "'."]
49     show (NoSuchModule name)       = concat ["No definition for module '", name, "'."]
50     show (NoSuchParameter name)    = concat ["Parameter '", name, "' not in scope."]
51     show (NoSuchVariable name)     = concat ["Variable '", name, "' not in scope."]
52     show (ParameterTypeMismatch name expected actual) =
53         concat ["Parameter '", name, "' of type '", show actual, "' used where type '", show expected, "' is required."]
54
55 newtype CheckFailure = CheckFailure
56     { failedChecks :: [FailedCheck] }
57
58 instance Show CheckFailure where
59     show (CheckFailure fs) = unlines $ map (("    " ++) . show) fs
60
61 data Context = Context
62     { spec       :: ASTI.SockeyeSpec
63     , moduleName :: !String
64     , vars       :: Set String
65     }
66
67 checkSockeye :: ASTF.SockeyeSpec -> Either CheckFailure ASTI.SockeyeSpec
68 checkSockeye ast = do
69     symbolTable <- buildSymbolTable ast
70     let
71         context = Context
72             { spec       = symbolTable
73             , moduleName = ""
74             , vars       = Set.empty
75             }
76     check context ast
77 -- build symbol table
78 -- check modules:
79 --  - parameter types must match usage site types
80 --  - all variables must exist
81 --  - 
82 --  - all instantiated modules must exist
83 --  - modules can not instantiate themselves
84 --  - instantiation argument types must match parameter types
85
86 --
87 -- Build Symbol table
88 --
89 class SymbolSource a b where
90     buildSymbolTable :: a -> Either CheckFailure b
91
92 instance SymbolSource ASTF.SockeyeSpec ASTI.SockeyeSpec where
93     buildSymbolTable ast = do
94         let
95             modules = (rootModule ast):(ASTF.modules ast)
96             names = map ASTF.name modules
97         checkDuplicates names DuplicateModule
98         symbolTables <- forAll buildSymbolTable modules
99         let
100             moduleMap = Map.fromList $ zip names symbolTables
101         return ASTI.SockeyeSpec
102                 { ASTI.modules = moduleMap }
103
104 instance SymbolSource ASTF.Module ASTI.Module where
105     buildSymbolTable ast = do
106         let
107             paramNames = map ASTF.paramName (ASTF.parameters ast)
108             paramTypes = map ASTF.paramType (ASTF.parameters ast)
109         checkDuplicates paramNames DuplicateParameter
110         let
111             paramTypeMap = Map.fromList $ zip paramNames paramTypes
112         return ASTI.Module
113             { ASTI.paramNames  = paramNames
114             , ASTI.paramTypes  = paramTypeMap
115             , ASTI.inputPorts  = []
116             , ASTI.outputPorts = []
117             , ASTI.nodeDecls   = []
118             , ASTI.moduleInsts = []
119             }
120 --
121 -- Check module bodies
122 --
123 class Checkable a b where
124     check :: Context -> a -> Either CheckFailure b
125
126 instance Checkable ASTF.SockeyeSpec ASTI.SockeyeSpec where
127     check context ast = do
128         let
129             modules = (rootModule ast):(ASTF.modules ast)
130             names = map ASTF.name modules
131         checked <- forAll (check context) modules
132         let
133             sockeyeSpec = spec context
134             checkedMap = Map.fromList $ zip names checked
135         return sockeyeSpec
136             { ASTI.modules = checkedMap }
137
138 instance Checkable ASTF.Module ASTI.Module where
139     check context ast = do
140         let
141             name = ASTF.name ast
142             bodyContext = context
143                 { moduleName = name}
144             body = ASTF.moduleBody ast
145             portDefs = ASTF.ports body
146             netSpecs = ASTF.moduleNet body
147         inputPorts  <- forAll (check bodyContext) $ filter isInPort  portDefs
148         outputPorts <- forAll (check bodyContext) $ filter isOutPort portDefs
149         checkedNetSpecs <- forAll (checkNetSpec bodyContext) netSpecs
150         let
151             checkedNodeDecls = lefts checkedNetSpecs
152             checkedModuleInsts = rights checkedNetSpecs
153         mod <- getCurrentModule bodyContext
154         return mod
155             { ASTI.inputPorts  = inputPorts
156             , ASTI.outputPorts = outputPorts
157             , ASTI.nodeDecls   = checkedNodeDecls
158             , ASTI.moduleInsts = checkedModuleInsts
159             }
160         where
161             isInPort (ASTF.InputPortDef _) = True
162             isInPort (ASTF.MultiPortDef for) = isInPort $ ASTF.body for
163             isInPort _ = False
164             isOutPort = not . isInPort
165             checkNetSpec context (ASTF.NodeDeclSpec decl) = do
166                 checkedDecl <- check context decl
167                 return $ Left checkedDecl
168             checkNetSpec context (ASTF.ModuleInstSpec inst) = do
169                 checkedInst <- check context inst
170                 return $ Right checkedInst
171
172 instance Checkable ASTF.PortDef ASTI.Port where
173     check context (ASTF.MultiPortDef for) = do
174         checkedFor <- check context for
175         return $ ASTI.MultiPort checkedFor
176     check context portDef = do
177         checkedId <- check context (ASTF.portId portDef)
178         return $ ASTI.Port checkedId
179
180 instance Checkable ASTF.ModuleInst ASTI.ModuleInst where
181     check context (ASTF.MultiModuleInst for) = do
182         checkedFor <- check context for
183         return $ ASTI.MultiModuleInst checkedFor
184     check context ast = do
185         let
186             nameSpace = ASTF.nameSpace ast
187             name = ASTF.moduleName ast
188             arguments = ASTF.arguments ast
189             portMaps = ASTF.portMappings ast
190         checkedNameSpace <- check context nameSpace
191         mod <- getModule context name
192         return ASTI.ModuleInst
193             { ASTI.nameSpace     = checkedNameSpace
194             , ASTI.moduleName    = name
195             , ASTI.arguments     = Map.empty
196             , ASTI.inputPortMap  = []
197             , ASTI.outputPortMap = []
198             }
199
200 instance Checkable ASTF.NodeDecl ASTI.NodeDecl where
201     check context (ASTF.MultiNodeDecl for) = do
202         checkedFor <- check context for
203         return $ ASTI.MultiNodeDecl checkedFor
204     check context ast = do
205         let
206             nodeId = ASTF.nodeId ast
207             nodeSpec = ASTF.nodeSpec ast
208         checkedId <- check context nodeId
209         checkedSpec <- check context nodeSpec
210         return ASTI.NodeDecl
211             { ASTI.nodeId   = checkedId
212             , ASTI.nodeSpec = checkedSpec
213             }
214
215 instance Checkable ASTF.Identifier ASTI.Identifier where
216     check _ (ASTF.SimpleIdent name) = return $ ASTI.SimpleIdent name
217     check context ast = do
218         let
219             prefix = ASTF.prefix ast
220             varName = ASTF.varName ast
221             suffix = ASTF.suffix ast
222         checkVarInScope context varName
223         checkedSuffix <- case suffix of
224             Nothing    -> return Nothing
225             Just ident -> do
226                 checkedIdent <- check context ident
227                 return $ Just checkedIdent
228         return ASTI.TemplateIdent
229             { ASTI.prefix  = prefix
230             , ASTI.varName = varName
231             , ASTI.suffix  = checkedSuffix
232             }
233
234 instance Checkable ASTF.NodeSpec ASTI.NodeSpec where
235     check context ast = do
236         let 
237             nodeType = ASTF.nodeType ast
238             accept = ASTF.accept ast
239             translate = ASTF.translate ast
240             overlay = ASTF.overlay ast
241         checkedAccept <- forAll (check context) accept
242         checkedTranslate <- forAll (check context) translate
243         checkedOverlay <- case overlay of
244             Nothing    -> return Nothing
245             Just ident -> do
246                 checkedIdent <- check context ident
247                 return $ Just checkedIdent
248         return ASTI.NodeSpec
249             { ASTI.nodeType  = nodeType
250             , ASTI.accept    = checkedAccept
251             , ASTI.translate = checkedTranslate
252             , ASTI.overlay   = checkedOverlay
253             }
254
255 instance Checkable ASTF.BlockSpec ASTI.BlockSpec where
256     check context (ASTF.SingletonBlock address) = do
257         checkedAddress <- check context address
258         return ASTI.SingletonBlock
259             { ASTI.address = checkedAddress }
260     check context (ASTF.RangeBlock base limit) = do
261         let
262             addresses = [base, limit]
263         checkedAddresses <- forAll (check context) addresses
264         return ASTI.RangeBlock
265             { ASTI.base  = head checkedAddresses
266             , ASTI.limit = last checkedAddresses
267             }
268     check context (ASTF.LengthBlock base bits) = do
269         checkedBase <- check context base
270         return ASTI.LengthBlock
271             { ASTI.base = checkedBase
272             , ASTI.bits = bits
273             }
274
275 instance Checkable ASTF.MapSpec ASTI.MapSpec where
276     check context ast = do
277         let
278             block = ASTF.block ast
279             destNode = ASTF.destNode ast
280             destBase = ASTF.destBase ast
281         checkedBlock <- check context block
282         checkedDestNode <- check context destNode
283         checkedDestBase <- case destBase of
284             Nothing      -> return Nothing
285             Just address -> do
286                 checkedAddress <- check context address
287                 return $ Just checkedAddress
288         return ASTI.MapSpec
289             { ASTI.block    = checkedBlock
290             , ASTI.destNode = checkedDestNode
291             , ASTI.destBase = checkedDestBase
292             }
293
294 instance Checkable ASTF.Address ASTI.Address where
295     check _ (ASTF.NumberAddress value) = do
296         return $ ASTI.NumberAddress value
297     check context (ASTF.ParamAddress name) = do
298         checkParamType context name ASTI.AddressParam
299         return $ ASTI.ParamAddress name
300
301 instance Checkable a b => Checkable (ASTF.For a) (ASTI.For b) where
302     check context ast = do
303         let
304             varNames = map ASTF.var (ASTF.varRanges ast)
305         checkDuplicates varNames DuplicateVariable
306         ranges <- forAll (check context) (ASTF.varRanges ast)
307         let
308             currentVars = vars context
309             bodyVars = currentVars `Set.union` (Set.fromList varNames)
310             bodyContext = context
311                 { vars = bodyVars }
312         body <- check bodyContext $ ASTF.body ast
313         let
314             varRanges = Map.fromList $ zip varNames ranges
315         return ASTI.For
316                 { ASTI.varRanges = varRanges
317                 , ASTI.body      = body
318                 }
319
320 instance Checkable ASTF.ForVarRange ASTI.ForRange where
321     check context ast = do
322         let
323             limits = [ASTF.start ast, ASTF.end ast]
324         checkedLimits <- forAll (check context) limits
325         return ASTI.ForRange
326             { ASTI.start = head checkedLimits
327             , ASTI.end   = last checkedLimits
328             }
329
330 instance Checkable ASTF.ForLimit ASTI.ForLimit where
331     check _ (ASTF.NumberLimit value) = do
332         return $ ASTI.NumberLimit value
333     check context (ASTF.ParamLimit name) = do
334         checkParamType context name ASTI.NumberParam
335         return $ ASTI.ParamLimit name
336 --
337 -- Helpers
338 --
339 rootModule :: ASTF.SockeyeSpec -> ASTF.Module
340 rootModule spec =
341     let
342         body = ASTF.ModuleBody
343             { ASTF.ports = []
344             , ASTF.moduleNet = ASTF.net spec
345             }
346     in ASTF.Module
347         { ASTF.name       = "@root"
348         , ASTF.parameters = []
349         , ASTF.moduleBody = body
350         }
351
352 getCurrentModule :: Context -> Either CheckFailure ASTI.Module
353 getCurrentModule context = do
354     let
355         modMap = ASTI.modules $ spec context
356     return $ modMap Map.! (moduleName context)
357
358 getModule :: Context -> String -> Either CheckFailure ASTI.Module
359 getModule context name = do
360     let
361         modMap = ASTI.modules $ spec context
362     case Map.lookup name modMap of
363         Nothing -> Left $ CheckFailure [NoSuchModule name]
364         Just m  -> return m
365
366 getParameterType :: Context -> String -> Either CheckFailure ASTI.ModuleParamType
367 getParameterType context name = do
368     mod <- getCurrentModule context
369     let
370         paramMap = ASTI.paramTypes mod
371     case Map.lookup name paramMap of
372         Nothing -> Left $ CheckFailure [NoSuchParameter name]
373         Just t  -> return t
374
375 forAll :: (a -> Either CheckFailure b) -> [a] -> Either CheckFailure [b]
376 forAll f as = do
377     let
378         bs = map f as
379         es = concat $ map failedChecks (lefts bs)
380     case es of
381         [] -> return $ rights bs
382         _  -> Left $ CheckFailure es
383
384 checkDuplicates :: [String] -> (String -> FailedCheck) -> Either CheckFailure ()
385 checkDuplicates names failure = do
386     let
387         duplicates = duplicateNames names
388     case duplicates of
389         [] -> return ()
390         _  -> Left $ CheckFailure (map failure duplicates)
391     where
392         duplicateNames [] = []
393         duplicateNames (x:xs)
394             | x `elem` xs = nub $ [x] ++ duplicateNames xs
395             | otherwise = duplicateNames xs
396
397 checkVarInScope :: Context -> String -> Either CheckFailure ()
398 checkVarInScope context name = do
399     if name `Set.member` (vars context)
400         then return ()
401         else Left $ CheckFailure [NoSuchVariable name]
402
403
404 checkParamType :: Context -> String -> ASTI.ModuleParamType -> Either CheckFailure ()
405 checkParamType context name expected = do
406     actual <- getParameterType context name
407     if actual == expected
408         then return ()
409         else Left $ mismatch actual
410     where
411         mismatch t = CheckFailure [ParameterTypeMismatch name expected t]