Sockeye: Checker now checks everything
[barrelfish] / tools / sockeye / SockeyeChecker.hs
1 {-
2     SockeyeChecker.hs: AST checker for Sockeye
3
4     Part of Sockeye
5
6     Copyright (c) 2017, ETH Zurich.
7
8     All rights reserved.
9
10     This file is distributed under the terms in the attached LICENSE file.
11     If you do not find this file, copies can be found by writing to:
12     ETH Zurich D-INFK, CAB F.78, Universitaetstr. 6, CH-8092 Zurich,
13     Attn: Systems Group.
14 -}
15
16 {-# LANGUAGE MultiParamTypeClasses #-}
17 {-# LANGUAGE FlexibleContexts #-}
18
19 module SockeyeChecker
20 ( checkSockeye ) where
21
22 import Control.Monad (join)
23
24 import Data.List (nub)
25 import Data.Map (Map)
26 import qualified Data.Map as Map
27 import Data.Set (Set)
28 import qualified Data.Set as Set
29 import Data.Either
30
31 import qualified SockeyeASTFrontend as ASTF
32 import qualified SockeyeASTIntermediate as ASTI
33
34 import Debug.Trace
35
36 data FailedCheck
37     = DuplicateModule String
38     | DuplicateParameter String
39     | DuplicateVariable String
40     | NoSuchModule String
41     | NoSuchParameter String
42     | NoSuchVariable String
43     | ParamTypeMismatch String ASTI.ModuleParamType ASTI.ModuleParamType
44     | WrongNumberOfArgs String Int Int
45     | ArgTypeMismatch String String ASTI.ModuleParamType ASTI.ModuleParamType
46
47 instance Show FailedCheck where
48     show (DuplicateModule name)    = concat ["Multiple definitions for module '", name, "'."]
49     show (DuplicateParameter name) = concat ["Multiple definitions for parameter '", name, "'."]
50     show (DuplicateVariable name)  = concat ["Multiple definitions for variable '", name, "'."]
51     show (NoSuchModule name)       = concat ["No definition for module '", name, "'."]
52     show (NoSuchParameter name)    = concat ["Parameter '", name, "' not in scope."]
53     show (NoSuchVariable name)     = concat ["Variable '", name, "' not in scope."]
54     show (ParamTypeMismatch name expected actual) =
55         concat ["Parameter '", name, "' of type '", show actual, "' used as '", show expected, "'."]
56     show (WrongNumberOfArgs name has given) =
57         concat ["Module '", name, "' takes ", show has, " arguments, given ", show given, "."]
58     show (ArgTypeMismatch modName paramName expected actual) =
59         concat ["Argument '", paramName, "' to module '", modName, "' of type '", show expected, "' instantiated with type '", show actual, "."]
60
61 newtype CheckFailure = CheckFailure
62     { failedChecks :: [FailedCheck] }
63
64 instance Show CheckFailure where
65     show (CheckFailure fs) = unlines $ map (("    " ++) . show) fs
66
67 data Context = Context
68     { spec       :: ASTI.SockeyeSpec
69     , moduleName :: !String
70     , vars       :: Set String
71     }
72
73 checkSockeye :: ASTF.SockeyeSpec -> Either CheckFailure ASTI.SockeyeSpec
74 checkSockeye ast = do
75     symbolTable <- buildSymbolTable ast
76     let
77         context = Context
78             { spec       = symbolTable
79             , moduleName = ""
80             , vars       = Set.empty
81             }
82     check context ast
83 -- build symbol table
84 -- check modules:
85 --  - parameter types must match usage site types
86 --  - all variables must exist
87 --  - 
88 --  - all instantiated modules must exist
89 --  - modules can not instantiate themselves
90 --  - instantiation argument types must match parameter types
91
92 --
93 -- Build Symbol table
94 --
95 class SymbolSource a b where
96     buildSymbolTable :: a -> Either CheckFailure b
97
98 instance SymbolSource ASTF.SockeyeSpec ASTI.SockeyeSpec where
99     buildSymbolTable ast = do
100         let
101             modules = (rootModule ast):(ASTF.modules ast)
102             names = map ASTF.name modules
103         checkDuplicates names DuplicateModule
104         symbolTables <- forAll buildSymbolTable modules
105         let
106             moduleMap = Map.fromList $ zip names symbolTables
107         return ASTI.SockeyeSpec
108                 { ASTI.modules = moduleMap }
109
110 instance SymbolSource ASTF.Module ASTI.Module where
111     buildSymbolTable ast = do
112         let
113             paramNames = map ASTF.paramName (ASTF.parameters ast)
114             paramTypes = map ASTF.paramType (ASTF.parameters ast)
115         checkDuplicates paramNames DuplicateParameter
116         let
117             paramTypeMap = Map.fromList $ zip paramNames paramTypes
118         return ASTI.Module
119             { ASTI.paramNames   = paramNames
120             , ASTI.paramTypeMap = paramTypeMap
121             , ASTI.inputPorts   = []
122             , ASTI.outputPorts  = []
123             , ASTI.nodeDecls    = []
124             , ASTI.moduleInsts  = []
125             }
126 --
127 -- Check module bodies
128 --
129 class Checkable a b where
130     check :: Context -> a -> Either CheckFailure b
131
132 instance Checkable ASTF.SockeyeSpec ASTI.SockeyeSpec where
133     check context ast = do
134         let
135             modules = (rootModule ast):(ASTF.modules ast)
136             names = map ASTF.name modules
137         checked <- forAll (check context) modules
138         let
139             sockeyeSpec = spec context
140             checkedMap = Map.fromList $ zip names checked
141         return sockeyeSpec
142             { ASTI.modules = checkedMap }
143
144 instance Checkable ASTF.Module ASTI.Module where
145     check context ast = do
146         let
147             name = ASTF.name ast
148             bodyContext = context
149                 { moduleName = name}
150             body = ASTF.moduleBody ast
151             portDefs = ASTF.ports body
152             netSpecs = ASTF.moduleNet body
153         inputPorts  <- forAll (check bodyContext) $ filter isInPort  portDefs
154         outputPorts <- forAll (check bodyContext) $ filter isOutPort portDefs
155         checkedNetSpecs <- forAll (checkNetSpec bodyContext) netSpecs
156         let
157             checkedNodeDecls = lefts checkedNetSpecs
158             checkedModuleInsts = rights checkedNetSpecs
159             mod = getCurrentModule bodyContext
160         return mod
161             { ASTI.inputPorts  = inputPorts
162             , ASTI.outputPorts = outputPorts
163             , ASTI.nodeDecls   = checkedNodeDecls
164             , ASTI.moduleInsts = checkedModuleInsts
165             }
166         where
167             isInPort (ASTF.InputPortDef _) = True
168             isInPort (ASTF.MultiPortDef for) = isInPort $ ASTF.body for
169             isInPort _ = False
170             isOutPort = not . isInPort
171             checkNetSpec context (ASTF.NodeDeclSpec decl) = do
172                 checkedDecl <- check context decl
173                 return $ Left checkedDecl
174             checkNetSpec context (ASTF.ModuleInstSpec inst) = do
175                 checkedInst <- check context inst
176                 return $ Right checkedInst
177
178 instance Checkable ASTF.PortDef ASTI.Port where
179     check context (ASTF.MultiPortDef for) = do
180         checkedFor <- check context for
181         return $ ASTI.MultiPort checkedFor
182     check context portDef = do
183         checkedId <- check context (ASTF.portId portDef)
184         return $ ASTI.Port checkedId
185
186 instance Checkable ASTF.ModuleInst ASTI.ModuleInst where
187     check context (ASTF.MultiModuleInst for) = do
188         checkedFor <- check context for
189         return $ ASTI.MultiModuleInst checkedFor
190     check context ast = do
191         let
192             nameSpace = ASTF.nameSpace ast
193             name = ASTF.moduleName ast
194             arguments = ASTF.arguments ast
195             portMaps = ASTF.portMappings ast
196         mod <- getModule context name
197         checkedNameSpace <- check context nameSpace
198         checkArgCount name mod arguments
199         checkedArgs <- checkArgTypes name mod arguments 
200         inPortMap  <- forAll (check context) $ filter isInMap  portMaps
201         outPortMap <- forAll (check context) $ filter isOutMap portMaps
202         return ASTI.ModuleInst
203             { ASTI.nameSpace  = checkedNameSpace
204             , ASTI.moduleName = name
205             , ASTI.arguments  = checkedArgs
206             , ASTI.inPortMap  = inPortMap
207             , ASTI.outPortMap = outPortMap
208             }
209         where
210             isInMap (ASTF.InputPortMap {}) = True
211             isInMap (ASTF.MultiPortMap for) = isInMap $ ASTF.body for
212             isInMap _ = False
213             isOutMap = not . isInMap
214             checkArgCount modName mod args = do
215                 let
216                     paramc = length $ ASTI.paramNames mod
217                     argc = length args
218                 if argc == paramc
219                     then return ()
220                     else Left $ CheckFailure [WrongNumberOfArgs modName paramc argc]
221             checkArgTypes modName mod args = do
222                 let
223                     paramNames = ASTI.paramNames mod
224                 checkedArgs <- forAll id $ zipWith (checkArgType modName mod) args paramNames
225                 return $ Map.fromList $ zip paramNames checkedArgs
226             checkArgType modName mod arg paramName = do
227                 let
228                     expected = getParameterType mod paramName
229                 case arg of
230                     ASTF.AddressArg value -> do
231                         if expected == ASTI.AddressParam
232                             then return $ ASTI.AddressArg value
233                             else Left $ mismatch expected ASTI.AddressParam
234                     ASTF.NumberArg value -> do
235                         if expected == ASTI.NumberParam
236                             then return $ ASTI.NumberArg value
237                             else Left $ mismatch expected ASTI.NumberParam
238                     ASTF.ParamArg name -> do
239                         checkParamType context name expected
240                         return $ ASTI.ParamArg name
241                 where
242                     mismatch expected actual = CheckFailure [ArgTypeMismatch modName paramName expected actual]
243
244 instance Checkable ASTF.PortMap ASTI.PortMap where
245     check context (ASTF.MultiPortMap for) = do
246         checkedFor <- check context for
247         return $ ASTI.MultiPortMap checkedFor
248     check context portMap = do
249         let
250             mappedId = ASTF.mappedId portMap
251             mappedPort = ASTF.mappedPort portMap
252             idents = [mappedId, mappedPort]
253         checkedIds <- forAll (check context) idents
254         return $ ASTI.PortMap
255             { ASTI.mappedId   = head checkedIds
256             , ASTI.mappedPort = last checkedIds
257             }
258
259 instance Checkable ASTF.NodeDecl ASTI.NodeDecl where
260     check context (ASTF.MultiNodeDecl for) = do
261         checkedFor <- check context for
262         return $ ASTI.MultiNodeDecl checkedFor
263     check context ast = do
264         let
265             nodeId = ASTF.nodeId ast
266             nodeSpec = ASTF.nodeSpec ast
267         checkedId <- check context nodeId
268         checkedSpec <- check context nodeSpec
269         return ASTI.NodeDecl
270             { ASTI.nodeId   = checkedId
271             , ASTI.nodeSpec = checkedSpec
272             }
273
274 instance Checkable ASTF.Identifier ASTI.Identifier where
275     check _ (ASTF.SimpleIdent name) = return $ ASTI.SimpleIdent name
276     check context ast = do
277         let
278             prefix = ASTF.prefix ast
279             varName = ASTF.varName ast
280             suffix = ASTF.suffix ast
281         checkVarInScope context varName
282         checkedSuffix <- case suffix of
283             Nothing    -> return Nothing
284             Just ident -> do
285                 checkedIdent <- check context ident
286                 return $ Just checkedIdent
287         return ASTI.TemplateIdent
288             { ASTI.prefix  = prefix
289             , ASTI.varName = varName
290             , ASTI.suffix  = checkedSuffix
291             }
292
293 instance Checkable ASTF.NodeSpec ASTI.NodeSpec where
294     check context ast = do
295         let 
296             nodeType = ASTF.nodeType ast
297             accept = ASTF.accept ast
298             translate = ASTF.translate ast
299             overlay = ASTF.overlay ast
300         checkedAccept <- forAll (check context) accept
301         checkedTranslate <- forAll (check context) translate
302         checkedOverlay <- case overlay of
303             Nothing    -> return Nothing
304             Just ident -> do
305                 checkedIdent <- check context ident
306                 return $ Just checkedIdent
307         return ASTI.NodeSpec
308             { ASTI.nodeType  = nodeType
309             , ASTI.accept    = checkedAccept
310             , ASTI.translate = checkedTranslate
311             , ASTI.overlay   = checkedOverlay
312             }
313
314 instance Checkable ASTF.BlockSpec ASTI.BlockSpec where
315     check context (ASTF.SingletonBlock address) = do
316         checkedAddress <- check context address
317         return ASTI.SingletonBlock
318             { ASTI.address = checkedAddress }
319     check context (ASTF.RangeBlock base limit) = do
320         let
321             addresses = [base, limit]
322         checkedAddresses <- forAll (check context) addresses
323         return ASTI.RangeBlock
324             { ASTI.base  = head checkedAddresses
325             , ASTI.limit = last checkedAddresses
326             }
327     check context (ASTF.LengthBlock base bits) = do
328         checkedBase <- check context base
329         return ASTI.LengthBlock
330             { ASTI.base = checkedBase
331             , ASTI.bits = bits
332             }
333
334 instance Checkable ASTF.MapSpec ASTI.MapSpec where
335     check context ast = do
336         let
337             block = ASTF.block ast
338             destNode = ASTF.destNode ast
339             destBase = ASTF.destBase ast
340         checkedBlock <- check context block
341         checkedDestNode <- check context destNode
342         checkedDestBase <- case destBase of
343             Nothing      -> return Nothing
344             Just address -> do
345                 checkedAddress <- check context address
346                 return $ Just checkedAddress
347         return ASTI.MapSpec
348             { ASTI.block    = checkedBlock
349             , ASTI.destNode = checkedDestNode
350             , ASTI.destBase = checkedDestBase
351             }
352
353 instance Checkable ASTF.Address ASTI.Address where
354     check _ (ASTF.NumberAddress value) = do
355         return $ ASTI.NumberAddress value
356     check context (ASTF.ParamAddress name) = do
357         checkParamType context name ASTI.AddressParam
358         return $ ASTI.ParamAddress name
359
360 instance Checkable a b => Checkable (ASTF.For a) (ASTI.For b) where
361     check context ast = do
362         let
363             varNames = map ASTF.var (ASTF.varRanges ast)
364         checkDuplicates varNames DuplicateVariable
365         ranges <- forAll (check context) (ASTF.varRanges ast)
366         let
367             currentVars = vars context
368             bodyVars = currentVars `Set.union` (Set.fromList varNames)
369             bodyContext = context
370                 { vars = bodyVars }
371         body <- check bodyContext $ ASTF.body ast
372         let
373             varRanges = Map.fromList $ zip varNames ranges
374         return ASTI.For
375                 { ASTI.varRanges = varRanges
376                 , ASTI.body      = body
377                 }
378
379 instance Checkable ASTF.ForVarRange ASTI.ForRange where
380     check context ast = do
381         let
382             limits = [ASTF.start ast, ASTF.end ast]
383         checkedLimits <- forAll (check context) limits
384         return ASTI.ForRange
385             { ASTI.start = head checkedLimits
386             , ASTI.end   = last checkedLimits
387             }
388
389 instance Checkable ASTF.ForLimit ASTI.ForLimit where
390     check _ (ASTF.NumberLimit value) = do
391         return $ ASTI.NumberLimit value
392     check context (ASTF.ParamLimit name) = do
393         checkParamType context name ASTI.NumberParam
394         return $ ASTI.ParamLimit name
395 --
396 -- Helpers
397 --
398 rootModule :: ASTF.SockeyeSpec -> ASTF.Module
399 rootModule spec =
400     let
401         body = ASTF.ModuleBody
402             { ASTF.ports = []
403             , ASTF.moduleNet = ASTF.net spec
404             }
405     in ASTF.Module
406         { ASTF.name       = "@root"
407         , ASTF.parameters = []
408         , ASTF.moduleBody = body
409         }
410
411 getModule :: Context -> String -> Either CheckFailure ASTI.Module
412 getModule context name = do
413     let
414         modMap = ASTI.modules $ spec context
415     case Map.lookup name modMap of
416         Nothing -> Left $ CheckFailure [NoSuchModule name]
417         Just m  -> return m
418
419 getCurrentModule :: Context -> ASTI.Module
420 getCurrentModule context =
421     let
422         modMap = ASTI.modules $ spec context
423     in modMap Map.! (moduleName context)
424
425 getParameterType :: ASTI.Module -> String -> ASTI.ModuleParamType
426 getParameterType mod name =
427     let
428         paramMap = ASTI.paramTypeMap mod
429     in paramMap Map.! name
430
431 getCurrentParameterType :: Context -> String -> Either CheckFailure ASTI.ModuleParamType
432 getCurrentParameterType context name = do
433     let
434         mod = getCurrentModule context
435         paramMap = ASTI.paramTypeMap mod
436     case Map.lookup name paramMap of
437         Nothing -> Left $ CheckFailure [NoSuchParameter name]
438         Just t  -> return t
439
440 forAll :: (a -> Either CheckFailure b) -> [a] -> Either CheckFailure [b]
441 forAll f as = do
442     let
443         bs = map f as
444         es = concat $ map failedChecks (lefts bs)
445     case es of
446         [] -> return $ rights bs
447         _  -> Left $ CheckFailure es
448
449 checkDuplicates :: [String] -> (String -> FailedCheck) -> Either CheckFailure ()
450 checkDuplicates names failure = do
451     let
452         duplicates = duplicateNames names
453     case duplicates of
454         [] -> return ()
455         _  -> Left $ CheckFailure (map failure duplicates)
456     where
457         duplicateNames [] = []
458         duplicateNames (x:xs)
459             | x `elem` xs = nub $ [x] ++ duplicateNames xs
460             | otherwise = duplicateNames xs
461
462 checkVarInScope :: Context -> String -> Either CheckFailure ()
463 checkVarInScope context name = do
464     if name `Set.member` (vars context)
465         then return ()
466         else Left $ CheckFailure [NoSuchVariable name]
467
468
469 checkParamType :: Context -> String -> ASTI.ModuleParamType -> Either CheckFailure ()
470 checkParamType context name expected = do
471     actual <- getCurrentParameterType context name
472     if actual == expected
473         then return ()
474         else Left $ mismatch actual
475     where
476         mismatch t = CheckFailure [ParamTypeMismatch name expected t]