Sockeye: implement reference check inside node declarations
[barrelfish] / tools / sockeye / SockeyeTypeChecker.hs
1 {-
2     SockeyeChecker.hs: AST checker for Sockeye
3
4     Part of Sockeye
5
6     Copyright (c) 2017, ETH Zurich.
7
8     All rights reserved.
9
10     This file is distributed under the terms in the attached LICENSE file.
11     If you do not find this file, copies can be found by writing to:
12     ETH Zurich D-INFK, CAB F.78, Universitaetstr. 6, CH-8092 Zurich,
13     Attn: Systems Group.
14 -}
15
16 {-# LANGUAGE MultiParamTypeClasses #-}
17 {-# LANGUAGE FlexibleInstances #-}
18 {-# LANGUAGE FlexibleContexts #-}
19
20 module SockeyeTypeChecker
21 ( typeCheckSockeye ) where
22
23 import Control.Monad
24
25 import Data.List (nub)
26 import Data.Map(Map)
27 import qualified Data.Map as Map
28 import Data.Set (Set)
29 import qualified Data.Set as Set
30 import Data.Either
31
32 import SockeyeChecks
33
34 import qualified SockeyeASTParser as ParseAST
35 import qualified SockeyeASTTypeChecker as CheckAST
36
37 import Debug.Trace
38
39 data TypeCheckFail
40     = DuplicateModule String
41     | DuplicateParameter String
42     | DuplicateVariable String
43     | NoSuchModule String
44     | NoSuchParameter String
45     | NoSuchVariable String
46     | ParamTypeMismatch String CheckAST.ModuleParamType CheckAST.ModuleParamType
47     | WrongNumberOfArgs String Int Int
48     | ArgTypeMismatch String String CheckAST.ModuleParamType CheckAST.ModuleParamType
49
50 instance Show TypeCheckFail where
51     show (DuplicateModule name)    = concat ["Multiple definitions for module '", name, "'"]
52     show (DuplicateParameter name) = concat ["Multiple parameters named '", name, "'"]
53     show (DuplicateVariable name)  = concat ["Multiple definitions for variable '", name, "'"]
54     show (NoSuchModule name)       = concat ["No definition for module '", name, "'"]
55     show (NoSuchParameter name)    = concat ["Parameter '", name, "' not in scope"]
56     show (NoSuchVariable name)     = concat ["Variable '", name, "' not in scope"]
57     show (WrongNumberOfArgs name takes given) = concat ["Module '", name, "' takes ", show takes, " argument(s), given ", show given]
58     show (ParamTypeMismatch name expected actual) =
59         concat ["Expected type '", show expected, "' but '", name, "' has type '", show actual, "'"]
60     show (ArgTypeMismatch modName name expected actual) =
61         concat ["Type mismatch for argument '", name, "' for module '", modName, "': Expected '", show expected, "', given '", show actual, "'"]
62
63 data ModuleSymbol = ModuleSymbol
64     { paramNames :: [String]
65     , paramTypes :: Map String CheckAST.ModuleParamType
66     }
67 type SymbolTable = Map String ModuleSymbol
68
69 data Context = Context
70     { symTable   :: SymbolTable
71     , curModule  :: !String
72     , instModule :: !String
73     , vars       :: Set String
74     }
75
76 typeCheckSockeye :: ParseAST.SockeyeSpec -> Either (FailedChecks TypeCheckFail) CheckAST.SockeyeSpec
77 typeCheckSockeye ast = do
78     symbolTable <- runChecks $ buildSymbolTable ast
79     let context = Context
80             { symTable   = symbolTable
81             , curModule  = ""
82             , instModule = ""
83             , vars       = Set.empty
84             }
85     runChecks $ check context ast
86
87 --
88 -- Build Symbol table
89 --
90 class SymbolSource a where
91     buildSymbolTable :: a -> Checks TypeCheckFail SymbolTable
92
93 instance SymbolSource ParseAST.SockeyeSpec where
94     buildSymbolTable ast = do
95         let mods = ParseAST.modules ast
96         symbolTables <- mapM buildSymbolTable mods
97         let names = concat $ map Map.keys symbolTables
98         checkDuplicates "@all" DuplicateModule names
99         return $ Map.unions symbolTables
100         
101 instance SymbolSource ParseAST.Module where
102     buildSymbolTable ast = do
103         let modName = ParseAST.name ast
104             params = ParseAST.parameters ast
105             names = map ParseAST.paramName params
106             types = map ParseAST.paramType params
107         checkDuplicates modName DuplicateParameter names
108         let typeMap = Map.fromList $ zip names types
109             modSymbol = ModuleSymbol
110                 { paramNames = names
111                 , paramTypes = typeMap
112                 }
113         return $ Map.singleton modName modSymbol
114
115 --
116 -- Check module bodies
117 --
118 class Checkable a b where
119     check :: Context -> a -> Checks TypeCheckFail b
120
121 instance Checkable ParseAST.SockeyeSpec CheckAST.SockeyeSpec where
122     check context ast = do
123         let mods = ParseAST.modules ast
124             rootNetSpecs = ParseAST.net ast
125             names = map ParseAST.name mods
126             rootName = "@root"
127             rootSymbol = ModuleSymbol
128                 { paramNames = []
129                 , paramTypes = Map.empty
130                 }
131             rootModContext = context
132                 { symTable = Map.insert rootName rootSymbol $ symTable context
133                 , curModule = rootName
134                 }
135         checkedRootNetSpecs <- check rootModContext rootNetSpecs
136         checkedModules <- check context mods
137         let root = CheckAST.ModuleInst
138                 { CheckAST.namespace  = CheckAST.SimpleIdent ""
139                 , CheckAST.moduleName = rootName
140                 , CheckAST.arguments  = Map.empty
141                 , CheckAST.inPortMap  = []
142                 , CheckAST.outPortMap = []
143                 }
144             rootModule = CheckAST.Module
145                 { CheckAST.paramNames   = []
146                 , CheckAST.paramTypeMap = Map.empty
147                 , CheckAST.ports        = []
148                 , CheckAST.nodeDecls    = lefts  checkedRootNetSpecs
149                 , CheckAST.moduleInsts  = rights checkedRootNetSpecs
150                 }
151             moduleMap = Map.fromList $ zip (rootName:names) (rootModule:checkedModules)
152         return CheckAST.SockeyeSpec
153             { CheckAST.root    = root
154             , CheckAST.modules = moduleMap
155             }
156
157 instance Checkable ParseAST.Module CheckAST.Module where
158     check context ast = do
159         let
160             name = ParseAST.name ast
161             body = ParseAST.moduleBody ast
162             ports = ParseAST.ports body
163             netSpecs = ParseAST.moduleNet body
164             symbol = (symTable context) Map.! name
165         let bodyContext = context
166                 { curModule = name }
167         checkedPorts <- check bodyContext ports
168         checkedNetSpecs <- check bodyContext netSpecs
169         let
170             checkedNodeDecls = lefts checkedNetSpecs
171             checkedModuleInsts = rights checkedNetSpecs
172         return CheckAST.Module
173             { CheckAST.paramNames   = paramNames symbol
174             , CheckAST.paramTypeMap = paramTypes symbol
175             , CheckAST.ports        = checkedPorts
176             , CheckAST.nodeDecls    = checkedNodeDecls
177             , CheckAST.moduleInsts  = checkedModuleInsts
178             }
179
180 instance Checkable ParseAST.Port CheckAST.Port where
181     check context (ParseAST.InputPort portId portWidth) = do
182         checkedId <- check context portId
183         return $ CheckAST.InputPort checkedId portWidth
184     check context (ParseAST.OutputPort portId portWidth) = do
185         checkedId <- check context portId
186         return $ CheckAST.OutputPort checkedId portWidth
187     check context (ParseAST.MultiPort for) = do
188         checkedFor <- check context for
189         return $ CheckAST.MultiPort checkedFor
190
191 instance Checkable ParseAST.NetSpec (Either CheckAST.NodeDecl CheckAST.ModuleInst) where
192     check context (ParseAST.NodeDeclSpec decl) = do
193         checkedDecl <- check context decl
194         return $ Left checkedDecl
195     check context (ParseAST.ModuleInstSpec inst) = do
196         checkedInst <- check context inst
197         return $ Right checkedInst
198
199 instance Checkable ParseAST.ModuleInst CheckAST.ModuleInst where
200     check context (ParseAST.MultiModuleInst for) = do
201         checkedFor <- check context for
202         return $ CheckAST.MultiModuleInst checkedFor
203     check context ast = do
204         let
205             namespace = ParseAST.namespace ast
206             name = ParseAST.moduleName ast
207             arguments = ParseAST.arguments ast
208             portMaps = ParseAST.portMappings ast
209         checkedArgs <- if Map.member name (symTable context)
210             then check (context { instModule = name }) arguments
211             else do
212                 failCheck (curModule context) $ NoSuchModule name
213                 return Map.empty
214         checkedNamespace <- check context namespace
215         inPortMap  <- check context $ filter isInMap  portMaps
216         outPortMap <- check context $ filter isOutMap portMaps
217         return CheckAST.ModuleInst
218             { CheckAST.namespace  = checkedNamespace
219             , CheckAST.moduleName = name
220             , CheckAST.arguments  = checkedArgs
221             , CheckAST.inPortMap  = inPortMap
222             , CheckAST.outPortMap = outPortMap
223             }
224         where
225             isInMap  (ParseAST.InputPortMap  {}) = True
226             isInMap  (ParseAST.OutputPortMap {}) = False
227             isInMap  (ParseAST.MultiPortMap for) = isInMap $ ParseAST.body for
228             isOutMap (ParseAST.InputPortMap  {}) = False
229             isOutMap (ParseAST.OutputPortMap {}) = True
230             isOutMap (ParseAST.MultiPortMap for) = isOutMap $ ParseAST.body for
231
232 instance Checkable [ParseAST.ModuleArg] (Map String CheckAST.ModuleArg) where
233     check context ast = do
234         let symbol = (symTable context) Map.! instName
235             names = paramNames symbol
236             expTypes = map (paramTypes symbol Map.!) names
237         checkArgCount names ast
238         checkedArgs <- zipWithM checkArgType (zip names expTypes) ast
239         return $ Map.fromList $ zip names checkedArgs
240         where
241             checkArgCount params args = do
242                 let
243                     paramc = length params
244                     argc = length args
245                 if argc == paramc
246                     then return ()
247                     else failCheck curName $ WrongNumberOfArgs instName paramc argc
248             checkArgType (name, expType) arg = do
249                 case arg of
250                     ParseAST.AddressArg value -> do
251                         if expType == CheckAST.AddressParam
252                             then return $ CheckAST.AddressArg value
253                             else do
254                                 mismatch CheckAST.AddressParam
255                                 return $ CheckAST.AddressArg value
256                     ParseAST.NaturalArg value -> do
257                         if expType == CheckAST.NaturalParam
258                             then return $ CheckAST.NaturalArg value
259                             else do
260                                 mismatch CheckAST.NaturalParam
261                                 return $ CheckAST.AddressArg value
262                     ParseAST.ParamArg paramName -> do
263                         checkParamType context paramName expType
264                         return $ CheckAST.ParamArg paramName
265                 where
266                     mismatch = failCheck curName . ArgTypeMismatch instName name expType
267             curName = curModule context
268             instName = instModule context
269
270 instance Checkable ParseAST.PortMap CheckAST.PortMap where
271     check context (ParseAST.MultiPortMap for) = do
272         checkedFor <- check context for
273         return $ CheckAST.MultiPortMap checkedFor
274     check context portMap = do
275         let
276             mappedId = ParseAST.mappedId portMap
277             mappedPort = ParseAST.mappedPort portMap
278         checkedId <- check context mappedId
279         checkedPort <- check context mappedPort
280         return $ CheckAST.PortMap
281             { CheckAST.mappedId   = checkedId
282             , CheckAST.mappedPort = checkedPort
283             }
284
285 instance Checkable ParseAST.NodeDecl CheckAST.NodeDecl where
286     check context (ParseAST.MultiNodeDecl for) = do
287         checkedFor <- check context for
288         return $ CheckAST.MultiNodeDecl checkedFor
289     check context ast = do
290         let
291             nodeId = ParseAST.nodeId ast
292             nodeSpec = ParseAST.nodeSpec ast
293         checkedId <- check context nodeId
294         checkedSpec <- check context nodeSpec
295         return CheckAST.NodeDecl
296             { CheckAST.nodeId   = checkedId
297             , CheckAST.nodeSpec = checkedSpec
298             }
299
300 instance Checkable ParseAST.Identifier CheckAST.Identifier where
301     check _ (ParseAST.SimpleIdent name) = return $ CheckAST.SimpleIdent name
302     check context ast = do
303         let
304             prefix = ParseAST.prefix ast
305             varName = ParseAST.varName ast
306             suffix = ParseAST.suffix ast
307         checkVarInScope context varName
308         checkedSuffix <- case suffix of
309             Nothing    -> return Nothing
310             Just ident -> do
311                 checkedIdent <- check context ident
312                 return $ Just checkedIdent
313         return CheckAST.TemplateIdent
314             { CheckAST.prefix  = prefix
315             , CheckAST.varName = varName
316             , CheckAST.suffix  = checkedSuffix
317             }
318
319 instance Checkable ParseAST.NodeSpec CheckAST.NodeSpec where
320     check context ast = do
321         let 
322             nodeType = ParseAST.nodeType ast
323             accept = ParseAST.accept ast
324             translate = ParseAST.translate ast
325             overlay = ParseAST.overlay ast
326             reserved = ParseAST.reserved ast
327         checkedAccept <- check context accept
328         checkedTranslate <- check context translate
329         checkedReserved <- check context reserved
330         checkedOverlay <- case overlay of
331             Nothing    -> return Nothing
332             Just ident -> do
333                 checkedIdent <- check context ident
334                 return $ Just checkedIdent
335         return CheckAST.NodeSpec
336             { CheckAST.nodeType  = nodeType
337             , CheckAST.accept    = checkedAccept
338             , CheckAST.translate = checkedTranslate
339             , CheckAST.reserved  = checkedReserved
340             , CheckAST.overlay   = checkedOverlay
341             }
342
343 instance Checkable ParseAST.BlockSpec CheckAST.BlockSpec where
344     check context (ParseAST.SingletonBlock address) = do
345         checkedAddress <- check context address
346         return CheckAST.SingletonBlock
347             { CheckAST.base = checkedAddress }
348     check context (ParseAST.RangeBlock base limit) = do
349         checkedBase <- check context base
350         checkedLimit <- check context limit
351         return CheckAST.RangeBlock
352             { CheckAST.base  = checkedBase
353             , CheckAST.limit = checkedLimit
354             }
355     check context (ParseAST.LengthBlock base bits) = do
356         checkedBase <- check context base
357         return CheckAST.LengthBlock
358             { CheckAST.base = checkedBase
359             , CheckAST.bits = bits
360             }
361
362 instance Checkable ParseAST.MapSpec CheckAST.MapSpec where
363     check context ast = do
364         let
365             block = ParseAST.block ast
366             destNode = ParseAST.destNode ast
367             destBase = ParseAST.destBase ast
368         checkedBlock <- check context block
369         checkedDestNode <- check context destNode
370         checkedDestBase <- case destBase of
371             Nothing      -> return Nothing
372             Just address -> do
373                 checkedAddress <- check context address
374                 return $ Just checkedAddress
375         return CheckAST.MapSpec
376             { CheckAST.block    = checkedBlock
377             , CheckAST.destNode = checkedDestNode
378             , CheckAST.destBase = checkedDestBase
379             }
380
381 instance Checkable ParseAST.OverlaySpec CheckAST.OverlaySpec where
382     check context (ParseAST.OverlaySpec over width) = do
383         checkedOver <- check context over
384         return $ CheckAST.OverlaySpec checkedOver width
385
386 instance Checkable ParseAST.Address CheckAST.Address where
387     check _ (ParseAST.LiteralAddress value) = do
388         return $ CheckAST.LiteralAddress value
389     check context (ParseAST.ParamAddress name) = do
390         checkParamType context name CheckAST.AddressParam
391         return $ CheckAST.ParamAddress name
392
393 instance Checkable a b => Checkable (ParseAST.For a) (CheckAST.For b) where
394     check context ast = do
395         let
396             varRanges = ParseAST.varRanges ast
397             varNames = map ParseAST.var varRanges
398             body = ParseAST.body ast
399             currentVars = vars context
400         checkDuplicates (curModule context) DuplicateVariable (varNames ++ Set.elems currentVars)
401         ranges <- check context varRanges
402         let
403             bodyVars = currentVars `Set.union` (Set.fromList varNames)
404             bodyContext = context
405                 { vars = bodyVars }
406         checkedBody <- check bodyContext body
407         let
408             checkedVarRanges = Map.fromList $ zip varNames ranges
409         return CheckAST.For
410                 { CheckAST.varRanges = checkedVarRanges
411                 , CheckAST.body      = checkedBody
412                 }
413
414 instance Checkable ParseAST.ForVarRange CheckAST.ForRange where
415     check context ast = do
416         let 
417             start = ParseAST.start ast
418             end = ParseAST.end ast
419         checkedStart <- check context start
420         checkedEnd<- check context end
421         return CheckAST.ForRange
422             { CheckAST.start = checkedStart
423             , CheckAST.end   = checkedEnd
424             }
425
426 instance Checkable ParseAST.ForLimit CheckAST.ForLimit where
427     check _ (ParseAST.LiteralLimit value) = do
428         return $ CheckAST.LiteralLimit value
429     check context (ParseAST.ParamLimit name) = do
430         checkParamType context name CheckAST.NaturalParam
431         return $ CheckAST.ParamLimit name
432
433 instance (Traversable t, Checkable a b) => Checkable (t a) (t b) where
434     check context as = mapM (check context) as
435
436 --
437 -- Helpers
438 --    
439 checkVarInScope :: Context -> String -> Checks TypeCheckFail ()
440 checkVarInScope context name = do
441     if name `Set.member` (vars context)
442         then return ()
443         else failCheck (curModule context) $ NoSuchVariable name
444
445
446 checkParamType :: Context -> String -> CheckAST.ModuleParamType -> Checks TypeCheckFail ()
447 checkParamType context name expected = do
448     let symbol = (symTable context) Map.! (curModule context)
449     case Map.lookup name $ paramTypes symbol of
450         Nothing -> failCheck (curModule context) $ NoSuchParameter name
451         Just actual -> do
452             if actual == expected
453                 then return ()
454                 else failCheck (curModule context) $ ParamTypeMismatch name expected actual