Skip to content

Commit ea21ee5

Browse files
committed
Working on typechecker.
1 parent 7dccb9f commit ea21ee5

File tree

6 files changed

+154
-43
lines changed

6 files changed

+154
-43
lines changed

lambdaQ.cabal

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ source-repository head
2626
library
2727
exposed-modules:
2828
Backend.CodeGenerator
29-
Common
3029
CompilationEngine
3130
Frontend.ASTtoIASTConverter
3231
Frontend.LambdaQ.Abs

src/Common.hs

Lines changed: 0 additions & 5 deletions
This file was deleted.

src/CompilationEngine.hs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import Control.Monad.Except
1717
)
1818
import Control.Exception (try)
1919

20-
import Common ( ErrorMessage )
2120
import Frontend.ASTtoIASTConverter (Program, runAstToIastConverter)
2221
import Frontend.SemanticAnalyser (runSemanticAnalyser)
2322
import Frontend.TypeChecker (runTypeChecker)
@@ -27,11 +26,11 @@ import Frontend.LambdaQ.Layout ( resolveLayout )
2726
import qualified Frontend.LambdaQ.Abs as GeneratedAbstractSyntax
2827

2928
data CompilationError =
30-
ParseError ErrorMessage |
31-
SemanticError ErrorMessage |
32-
SyntaxTreeConversionError ErrorMessage |
33-
TypeCheckError ErrorMessage |
34-
CodeGenerationError ErrorMessage |
29+
ParseError String |
30+
SemanticError String |
31+
SyntaxTreeConversionError String |
32+
TypeCheckError String |
33+
CodeGenerationError String |
3534
FileDoesNotExist FilePath
3635

3736
instance Show CompilationError where

src/Frontend/ASTtoIASTConverter.hs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ module Frontend.ASTtoIASTConverter (
1616
mapProgram,
1717
Program,
1818
runAstToIastConverter,
19+
simplifyTensorProd,
1920
Term(..),
2021
Type(..),
2122
) where
@@ -159,8 +160,8 @@ data Term =
159160
TermFreeVariable String |
160161
TermList List |
161162
TermListElement List Integer |
162-
TermBoolExpression BoolExpression |
163-
TermIntegerExpression IntegerExpression |
163+
TermBool BoolExpression |
164+
TermInteger IntegerExpression |
164165
TermVariable Var |
165166
TermIfElse Term Term Term |
166167
TermLet Term Term |
@@ -436,8 +437,8 @@ mapTerm env (GeneratedAbstractSyntax.TermListElement l index) = TermListElement
436437
mapTerm _ (GeneratedAbstractSyntax.TermBit (GeneratedAbstractSyntax.Bit bit)) = if bit == "0b0" then TermBit BitZero else TermBit BitOne
437438
mapTerm _ GeneratedAbstractSyntax.TermUnit = TermUnit
438439
mapTerm _ (GeneratedAbstractSyntax.TermBasisState bs) = TermBasisState (mapBasisState bs)
439-
mapTerm _ (GeneratedAbstractSyntax.TermBoolExpression be) = TermBoolExpression (mapBoolExpression be)
440-
mapTerm _ (GeneratedAbstractSyntax.TermIntegerExpression be) = TermIntegerExpression (mapIntegerExpression be)
440+
mapTerm _ (GeneratedAbstractSyntax.TermBoolExpression be) = TermBool (mapBoolExpression be)
441+
mapTerm _ (GeneratedAbstractSyntax.TermIntegerExpression be) = TermInteger (mapIntegerExpression be)
441442
mapTerm _ (GeneratedAbstractSyntax.TermGate g) = TermGate (mapGate g)
442443
mapTerm env (GeneratedAbstractSyntax.TermTuple term terms) = TermTuple (mapTerm env term) (map (mapTerm env) (term:terms))
443444
mapTerm env (GeneratedAbstractSyntax.TermApply l r) = TermApply (mapTerm env l) (mapTerm env r)

src/Frontend/TypeChecker.hs

Lines changed: 95 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,23 @@ where
1919
--import Data.Map (Map, lookup)
2020
--import Data.Set (Set, member, insert)
2121

22-
import qualified Common
23-
import qualified Data.Map
2422
import qualified Control.Monad.Except
2523
import qualified Control.Monad.Reader
2624
import qualified Control.Monad.State
2725
import qualified Control.Monad.State.Class
26+
import qualified Data.Map
27+
import qualified Data.Maybe
2828
import qualified Data.Set
2929

30-
import Frontend.ASTtoIASTConverter (Function(..), Gate(..), Program, Term(..), Type(..))
30+
import Frontend.ASTtoIASTConverter (Function(..), Gate(..), Program, Term(..), Type(..), simplifyTensorProd)
3131

3232
data TypeError
3333
= NotAFunction Type (Int, Int, String) -- this type should be a function but it is not
3434
| FunctionNotInScope String (Int, Int, String) -- this variable denotes a function which is not in scope at the point where it is used
3535
| TypeMismatch Type Type (Int, Int, String) -- this type does not match the type expected at the point where it was declared
3636
| NotAProductType Type (Int, Int, String) -- this type should be a product type but it is not
3737
| DuplicatedLinearVariable String (Int, Int, String) -- this linear variable is used more than once
38-
| NotALinearFunction String (Int, Int, String) -- this function is used more than once despite being declared linear
38+
| NotALinearFunction String (Int, Int, String) -- this function is used more than once despite not being declared linear
3939
| NotALinearTerm Term Type (Int, Int, String) -- this term should be linear but is is not
4040
| NoCommonSupertype Type Type (Int, Int, String) -- these two types have no common supertype
4141
deriving (Eq, Ord, Read)
@@ -54,8 +54,8 @@ instance Show TypeError where
5454

5555
show (DuplicatedLinearVariable var (line, _, fname)) = "The linear variable '" ++ var ++ "' in the top level function named: '" ++ fname ++ "' defined at line: " ++ show line ++ " is used more than once"
5656

57-
show (NotALinearFunction fun (line, _, fname)) = "The function named: '" ++ show fun ++ "' which is used in the top level function named: '" ++ fname
58-
++ ". defined at line: " ++ show line ++ " is used more than once despite being declared linear"
57+
show (NotALinearFunction fun (line, _, fname)) = "The function named: '" ++ fun ++ "' which is used in the top level function named: '" ++ fname
58+
++ "' defined at line: " ++ show line ++ " is used more than once despite not being declared linear"
5959

6060
show (NotALinearTerm term typ (line, _, fname)) = "Term: '" ++ show term ++ "' having as type: " ++ show typ
6161
++ " which occurs in function " ++ fname ++ " defined at line: " ++ show line ++ " is not linear"
@@ -75,7 +75,7 @@ data ErrorEnvironment = ErrorEnvironment
7575

7676
type Check = Control.Monad.Except.ExceptT TypeError (Control.Monad.Reader.ReaderT MainEnvironment (Control.Monad.State.State ErrorEnvironment))
7777

78-
runTypeChecker :: Program -> Either Common.ErrorMessage Program
78+
runTypeChecker :: Program -> Either String Program
7979
runTypeChecker program =
8080
case Control.Monad.State.evalState (Control.Monad.Reader.runReaderT (Control.Monad.Except.runExceptT (typeCheckProgram program)) mainEnv) errorEnv of
8181
Left err -> Left (show err)
@@ -97,17 +97,23 @@ typeCheckFunction :: Function -> Check ()
9797
typeCheckFunction (Function functionName (line, col) functionType term) = do
9898
Control.Monad.State.Class.modify $ \x -> x {currentFunction = functionName}
9999
inferredType <- inferType [] term (line, col, functionName)
100-
if isSubtype inferredType functionType
100+
if isSubtype inferredType functionType
101101
then return ()
102102
else Control.Monad.Except.throwError (TypeMismatch functionType inferredType (line, col, functionName))
103103

104+
-- typesMatch :: Type -> Type -> Bool
105+
-- typesMatch tl tr = (tl' == tr') || isSubtype tl tr
106+
-- where
107+
-- tl' = removeBangs $ pullOutBangs tl
108+
-- tr' = removeBangs $ pullOutBangs tr
109+
104110
isSubtype :: Type -> Type -> Bool
111+
isSubtype (TypeNonLinear t1) (TypeNonLinear t2) = isSubtype (TypeNonLinear t1) t2
105112
isSubtype (TypeNonLinear t1) t2 = isSubtype t1 t2
106-
isSubtype _ (TypeNonLinear _) = False
107-
isSubtype (TypeList t1) (TypeList t2) = isSubtype t1 t2
108113
isSubtype (t1 :->: t2) (t1' :->: t2') = isSubtype t1' t1 && isSubtype t2 t2'
109114
isSubtype (t1 :*: t2) (t1' :*: t2') = isSubtype t1 t1' && isSubtype t2 t2'
110115
isSubtype (t1 :**: n1) (t2 :**: n2) = n1 == n2 && isSubtype t1 t2
116+
isSubtype (TypeList t1) (TypeList t2) = isSubtype t1 t2
111117
isSubtype t1 t2 = t1 == t2
112118

113119
inferType :: [Type] -> Term -> (Int, Int, String) -> Check Type
@@ -119,9 +125,21 @@ inferType _ (TermPower _) _ = return $ TypeNonLinear (TypeQbits :->: TypeQbits)
119125
inferType _ (TermInverse _) _ = return $ TypeNonLinear (TypeQbits :->: TypeQbits)
120126
inferType _ (TermBit _) _ = return $ TypeNonLinear TypeBit
121127
inferType _ (TermGate gate) _ = return $ inferGateType gate
128+
inferType _ (TermBool _) _ = return $ TypeNonLinear TypeBool
129+
inferType _ (TermInteger _) _ = return $ TypeNonLinear TypeInteger
122130
inferType _ TermUnit _ = return $ TypeNonLinear TypeUnit
123131
inferType _ (TermBasisState _) _ = return TypeBasisState
124132

133+
inferType context (TermLambda typ term) (line, col, fname) = do
134+
mainEnv <- Control.Monad.Reader.ask
135+
checkLinearExpression term typ (line, col, fname)
136+
termTyp <- inferType (typ:context) term (line, col, fname)
137+
let boundedLinearVars = any (isLinear . (context !!) . fromIntegral) (freeVariables (TermLambda typ term))
138+
let freeLinearVars = any isLinear $ Data.Maybe.mapMaybe (`Data.Map.lookup` mainEnv) (extractFunctionNames term)
139+
if boundedLinearVars || freeLinearVars
140+
then return (typ :->: termTyp)
141+
else return $ TypeNonLinear (typ :->: termTyp)
142+
125143
inferType context (TermApply termLeft termRight) (line, col, fname) = do
126144
leftTermType <- inferType context termLeft (line, col, fname)
127145
rightTermType <- inferType context termRight (line, col, fname)
@@ -131,6 +149,16 @@ inferType context (TermApply termLeft termRight) (line, col, fname) = do
131149
| otherwise -> Control.Monad.Except.throwError $ TypeMismatch argsType rightTermType (line, col, fname)
132150
_ -> Control.Monad.Except.throwError $ NotAFunction leftTermType (line, col, fname)
133151

152+
inferType context (TermTuple l [r]) (line, col, fname) = do
153+
leftTyp <- inferType context l (line, col, fname)
154+
rightTyp <- inferType context r (line, col, fname)
155+
return $ simplifyTensorProd $ pullOutBangs (leftTyp :*: rightTyp)
156+
157+
inferType context (TermTuple l (r:rs)) (line, col, fname) = do
158+
leftTyp <- inferType context l (line, col, fname)
159+
rightTyp <- inferType context (TermTuple r rs) (line, col, fname)
160+
return $ simplifyTensorProd $ pullOutBangs (leftTyp :*: rightTyp)
161+
134162
inferType _ (TermFreeVariable var) (line, col, fname) = do
135163
mainEnv <- Control.Monad.Reader.ask
136164
linearEnv <- Control.Monad.State.gets linearEnvironment
@@ -172,4 +200,60 @@ isLinear _ = True
172200

173201
removeBangs :: Type -> Type
174202
removeBangs (TypeNonLinear t) = removeBangs t
175-
removeBangs t = t
203+
removeBangs t = t
204+
205+
pullOutBangs :: Type -> Type
206+
pullOutBangs (TypeNonLinear l :*: TypeNonLinear r) = TypeNonLinear (pullOutBangs (l :*: r))
207+
pullOutBangs (TypeNonLinear t :**: n) = TypeNonLinear (pullOutBangs (t :**: n))
208+
pullOutBangs t = t
209+
210+
checkLinearExpression :: Term -> Type -> (Int, Int, String) -> Check ()
211+
checkLinearExpression term typ (line, col, fname) = case typ of
212+
TypeNonLinear _ -> return ()
213+
t -> if headBoundVariableCount term <= 1
214+
then return ()
215+
else Control.Monad.Except.throwError $ NotALinearTerm term t (line, col, fname)
216+
217+
headBoundVariableCount :: Term -> Integer
218+
headBoundVariableCount = headBoundVariableCount' 0
219+
where
220+
headBoundVariableCount' :: Integer -> Term -> Integer
221+
headBoundVariableCount' cnt term = case term of
222+
TermBoundVariable i -> if cnt == i then 1 else 0
223+
TermLambda _ lambdaTerm -> headBoundVariableCount' (cnt + 1) lambdaTerm
224+
TermApply termLeft termRight -> headBoundVariableCount' cnt termLeft + headBoundVariableCount' cnt termRight
225+
TermCompose termLeft termRight -> headBoundVariableCount' cnt termLeft + headBoundVariableCount' cnt termRight
226+
TermDollar termLeft termRight -> headBoundVariableCount' cnt termLeft + headBoundVariableCount' cnt termRight
227+
TermIfElse cond t f -> headBoundVariableCount' cnt cond + max (headBoundVariableCount' cnt t) (headBoundVariableCount' cnt f)
228+
--TermTuple left right -> headBoundVariableCount' cnt left + headBoundVariableCount' cnt right
229+
-- TermList ListNil -> 0
230+
-- TermLetSingle termEq termIn -> headBoundVariableCount' cnt termEq + headBoundVariableCount' (cnt + 1) termIn --TODO: verify
231+
-- TermLetSugarSingle termEq termIn -> headBoundVariableCount' cnt termEq + headBoundVariableCount' (cnt + 1) termIn --TODO: verify
232+
-- TermLetMultiple termEq termIn -> headBoundVariableCount' cnt termEq + headBoundVariableCount' (cnt + 2) termIn
233+
-- TermLetSugarMultiple termEq termIn -> headBoundVariableCount' cnt termEq + headBoundVariableCount' (cnt + 2) termIn
234+
_ -> 0
235+
236+
freeVariables :: Term -> [Integer]
237+
freeVariables = freeVariables' 0
238+
where
239+
freeVariables' :: Integer -> Term -> [Integer]
240+
--freeVariables' cnt (TermTuple left right) = freeVariables' cnt left ++ freeVariables' cnt right
241+
freeVariables' cnt (TermApply termLeft termRight) = freeVariables' cnt termLeft ++ freeVariables' cnt termRight
242+
-- freeVariables' cnt (TermLetSingle termEq termIn) = freeVariables' cnt termEq ++ freeVariables' (cnt + 1) termIn --TODO: verify
243+
-- freeVariables' cnt (TermLetSugarSingle termEq termIn) = freeVariables' cnt termEq ++ freeVariables' (cnt + 1) termIn --TODO: verify
244+
-- freeVariables' cnt (TermLetMultiple termEq termIn) = freeVariables' cnt termEq ++ freeVariables' (cnt + 2) termIn
245+
-- freeVariables' cnt (TermLetSugarMultiple termEq termIn) = freeVariables' cnt termEq ++ freeVariables' (cnt + 2) termIn
246+
freeVariables' cnt (TermLambda _ lambdaTerm) = freeVariables' (cnt + 1) lambdaTerm
247+
freeVariables' cnt (TermBoundVariable i) = [i - cnt | i >= cnt]
248+
freeVariables' _ _ = []
249+
250+
extractFunctionNames :: Term -> [String]
251+
--extractFunctionNames (TermTuple left right) = extractFunctionNames left ++ extractFunctionNames right
252+
extractFunctionNames (TermApply termLeft termRight) = extractFunctionNames termLeft ++ extractFunctionNames termRight
253+
-- extractFunctionNames (TermLetSingle termEq termIn) = extractFunctionNames termEq ++ extractFunctionNames termIn
254+
-- extractFunctionNames (TermLetSugarSingle termEq termIn) = extractFunctionNames termEq ++ extractFunctionNames termIn
255+
-- extractFunctionNames (TermLetMultiple termEq termIn) = extractFunctionNames termEq ++ extractFunctionNames termIn
256+
-- extractFunctionNames (TermLetSugarMultiple termEq termIn) = extractFunctionNames termEq ++ extractFunctionNames termIn
257+
extractFunctionNames (TermLambda _ lambdaTerm) = extractFunctionNames lambdaTerm
258+
extractFunctionNames (TermFreeVariable fun) = [fun]
259+
extractFunctionNames _ = []

0 commit comments

Comments
 (0)