@@ -19,23 +19,23 @@ where
19
19
-- import Data.Map (Map, lookup)
20
20
-- import Data.Set (Set, member, insert)
21
21
22
- import qualified Common
23
- import qualified Data.Map
24
22
import qualified Control.Monad.Except
25
23
import qualified Control.Monad.Reader
26
24
import qualified Control.Monad.State
27
25
import qualified Control.Monad.State.Class
26
+ import qualified Data.Map
27
+ import qualified Data.Maybe
28
28
import qualified Data.Set
29
29
30
- import Frontend.ASTtoIASTConverter (Function (.. ), Gate (.. ), Program , Term (.. ), Type (.. ))
30
+ import Frontend.ASTtoIASTConverter (Function (.. ), Gate (.. ), Program , Term (.. ), Type (.. ), simplifyTensorProd )
31
31
32
32
data TypeError
33
33
= NotAFunction Type (Int , Int , String ) -- this type should be a function but it is not
34
34
| FunctionNotInScope String (Int , Int , String ) -- this variable denotes a function which is not in scope at the point where it is used
35
35
| TypeMismatch Type Type (Int , Int , String ) -- this type does not match the type expected at the point where it was declared
36
36
| NotAProductType Type (Int , Int , String ) -- this type should be a product type but it is not
37
37
| DuplicatedLinearVariable String (Int , Int , String ) -- this linear variable is used more than once
38
- | NotALinearFunction String (Int , Int , String ) -- this function is used more than once despite being declared linear
38
+ | NotALinearFunction String (Int , Int , String ) -- this function is used more than once despite not being declared linear
39
39
| NotALinearTerm Term Type (Int , Int , String ) -- this term should be linear but is is not
40
40
| NoCommonSupertype Type Type (Int , Int , String ) -- these two types have no common supertype
41
41
deriving (Eq , Ord , Read )
@@ -54,8 +54,8 @@ instance Show TypeError where
54
54
55
55
show (DuplicatedLinearVariable var (line, _, fname)) = " The linear variable '" ++ var ++ " ' in the top level function named: '" ++ fname ++ " ' defined at line: " ++ show line ++ " is used more than once"
56
56
57
- show (NotALinearFunction fun (line, _, fname)) = " The function named: '" ++ show fun ++ " ' which is used in the top level function named: '" ++ fname
58
- ++ " . defined at line: " ++ show line ++ " is used more than once despite being declared linear"
57
+ show (NotALinearFunction fun (line, _, fname)) = " The function named: '" ++ fun ++ " ' which is used in the top level function named: '" ++ fname
58
+ ++ " ' defined at line: " ++ show line ++ " is used more than once despite not being declared linear"
59
59
60
60
show (NotALinearTerm term typ (line, _, fname)) = " Term: '" ++ show term ++ " ' having as type: " ++ show typ
61
61
++ " which occurs in function " ++ fname ++ " defined at line: " ++ show line ++ " is not linear"
@@ -75,7 +75,7 @@ data ErrorEnvironment = ErrorEnvironment
75
75
76
76
type Check = Control.Monad.Except. ExceptT TypeError (Control.Monad.Reader. ReaderT MainEnvironment (Control.Monad.State. State ErrorEnvironment ))
77
77
78
- runTypeChecker :: Program -> Either Common. ErrorMessage Program
78
+ runTypeChecker :: Program -> Either String Program
79
79
runTypeChecker program =
80
80
case Control.Monad.State. evalState (Control.Monad.Reader. runReaderT (Control.Monad.Except. runExceptT (typeCheckProgram program)) mainEnv) errorEnv of
81
81
Left err -> Left (show err)
@@ -97,17 +97,23 @@ typeCheckFunction :: Function -> Check ()
97
97
typeCheckFunction (Function functionName (line, col) functionType term) = do
98
98
Control.Monad.State.Class. modify $ \ x -> x {currentFunction = functionName}
99
99
inferredType <- inferType [] term (line, col, functionName)
100
- if isSubtype inferredType functionType
100
+ if isSubtype inferredType functionType
101
101
then return ()
102
102
else Control.Monad.Except. throwError (TypeMismatch functionType inferredType (line, col, functionName))
103
103
104
+ -- typesMatch :: Type -> Type -> Bool
105
+ -- typesMatch tl tr = (tl' == tr') || isSubtype tl tr
106
+ -- where
107
+ -- tl' = removeBangs $ pullOutBangs tl
108
+ -- tr' = removeBangs $ pullOutBangs tr
109
+
104
110
isSubtype :: Type -> Type -> Bool
111
+ isSubtype (TypeNonLinear t1) (TypeNonLinear t2) = isSubtype (TypeNonLinear t1) t2
105
112
isSubtype (TypeNonLinear t1) t2 = isSubtype t1 t2
106
- isSubtype _ (TypeNonLinear _) = False
107
- isSubtype (TypeList t1) (TypeList t2) = isSubtype t1 t2
108
113
isSubtype (t1 :->: t2) (t1' :->: t2') = isSubtype t1' t1 && isSubtype t2 t2'
109
114
isSubtype (t1 :*: t2) (t1' :*: t2') = isSubtype t1 t1' && isSubtype t2 t2'
110
115
isSubtype (t1 :**: n1) (t2 :**: n2) = n1 == n2 && isSubtype t1 t2
116
+ isSubtype (TypeList t1) (TypeList t2) = isSubtype t1 t2
111
117
isSubtype t1 t2 = t1 == t2
112
118
113
119
inferType :: [Type ] -> Term -> (Int , Int , String ) -> Check Type
@@ -119,9 +125,21 @@ inferType _ (TermPower _) _ = return $ TypeNonLinear (TypeQbits :->: TypeQbits)
119
125
inferType _ (TermInverse _) _ = return $ TypeNonLinear (TypeQbits :->: TypeQbits )
120
126
inferType _ (TermBit _) _ = return $ TypeNonLinear TypeBit
121
127
inferType _ (TermGate gate) _ = return $ inferGateType gate
128
+ inferType _ (TermBool _) _ = return $ TypeNonLinear TypeBool
129
+ inferType _ (TermInteger _) _ = return $ TypeNonLinear TypeInteger
122
130
inferType _ TermUnit _ = return $ TypeNonLinear TypeUnit
123
131
inferType _ (TermBasisState _) _ = return TypeBasisState
124
132
133
+ inferType context (TermLambda typ term) (line, col, fname) = do
134
+ mainEnv <- Control.Monad.Reader. ask
135
+ checkLinearExpression term typ (line, col, fname)
136
+ termTyp <- inferType (typ: context) term (line, col, fname)
137
+ let boundedLinearVars = any (isLinear . (context !! ) . fromIntegral ) (freeVariables (TermLambda typ term))
138
+ let freeLinearVars = any isLinear $ Data.Maybe. mapMaybe (`Data.Map.lookup` mainEnv) (extractFunctionNames term)
139
+ if boundedLinearVars || freeLinearVars
140
+ then return (typ :->: termTyp)
141
+ else return $ TypeNonLinear (typ :->: termTyp)
142
+
125
143
inferType context (TermApply termLeft termRight) (line, col, fname) = do
126
144
leftTermType <- inferType context termLeft (line, col, fname)
127
145
rightTermType <- inferType context termRight (line, col, fname)
@@ -131,6 +149,16 @@ inferType context (TermApply termLeft termRight) (line, col, fname) = do
131
149
| otherwise -> Control.Monad.Except. throwError $ TypeMismatch argsType rightTermType (line, col, fname)
132
150
_ -> Control.Monad.Except. throwError $ NotAFunction leftTermType (line, col, fname)
133
151
152
+ inferType context (TermTuple l [r]) (line, col, fname) = do
153
+ leftTyp <- inferType context l (line, col, fname)
154
+ rightTyp <- inferType context r (line, col, fname)
155
+ return $ simplifyTensorProd $ pullOutBangs (leftTyp :*: rightTyp)
156
+
157
+ inferType context (TermTuple l (r: rs)) (line, col, fname) = do
158
+ leftTyp <- inferType context l (line, col, fname)
159
+ rightTyp <- inferType context (TermTuple r rs) (line, col, fname)
160
+ return $ simplifyTensorProd $ pullOutBangs (leftTyp :*: rightTyp)
161
+
134
162
inferType _ (TermFreeVariable var) (line, col, fname) = do
135
163
mainEnv <- Control.Monad.Reader. ask
136
164
linearEnv <- Control.Monad.State. gets linearEnvironment
@@ -172,4 +200,60 @@ isLinear _ = True
172
200
173
201
removeBangs :: Type -> Type
174
202
removeBangs (TypeNonLinear t) = removeBangs t
175
- removeBangs t = t
203
+ removeBangs t = t
204
+
205
+ pullOutBangs :: Type -> Type
206
+ pullOutBangs (TypeNonLinear l :*: TypeNonLinear r) = TypeNonLinear (pullOutBangs (l :*: r))
207
+ pullOutBangs (TypeNonLinear t :**: n) = TypeNonLinear (pullOutBangs (t :**: n))
208
+ pullOutBangs t = t
209
+
210
+ checkLinearExpression :: Term -> Type -> (Int , Int , String ) -> Check ()
211
+ checkLinearExpression term typ (line, col, fname) = case typ of
212
+ TypeNonLinear _ -> return ()
213
+ t -> if headBoundVariableCount term <= 1
214
+ then return ()
215
+ else Control.Monad.Except. throwError $ NotALinearTerm term t (line, col, fname)
216
+
217
+ headBoundVariableCount :: Term -> Integer
218
+ headBoundVariableCount = headBoundVariableCount' 0
219
+ where
220
+ headBoundVariableCount' :: Integer -> Term -> Integer
221
+ headBoundVariableCount' cnt term = case term of
222
+ TermBoundVariable i -> if cnt == i then 1 else 0
223
+ TermLambda _ lambdaTerm -> headBoundVariableCount' (cnt + 1 ) lambdaTerm
224
+ TermApply termLeft termRight -> headBoundVariableCount' cnt termLeft + headBoundVariableCount' cnt termRight
225
+ TermCompose termLeft termRight -> headBoundVariableCount' cnt termLeft + headBoundVariableCount' cnt termRight
226
+ TermDollar termLeft termRight -> headBoundVariableCount' cnt termLeft + headBoundVariableCount' cnt termRight
227
+ TermIfElse cond t f -> headBoundVariableCount' cnt cond + max (headBoundVariableCount' cnt t) (headBoundVariableCount' cnt f)
228
+ -- TermTuple left right -> headBoundVariableCount' cnt left + headBoundVariableCount' cnt right
229
+ -- TermList ListNil -> 0
230
+ -- TermLetSingle termEq termIn -> headBoundVariableCount' cnt termEq + headBoundVariableCount' (cnt + 1) termIn --TODO: verify
231
+ -- TermLetSugarSingle termEq termIn -> headBoundVariableCount' cnt termEq + headBoundVariableCount' (cnt + 1) termIn --TODO: verify
232
+ -- TermLetMultiple termEq termIn -> headBoundVariableCount' cnt termEq + headBoundVariableCount' (cnt + 2) termIn
233
+ -- TermLetSugarMultiple termEq termIn -> headBoundVariableCount' cnt termEq + headBoundVariableCount' (cnt + 2) termIn
234
+ _ -> 0
235
+
236
+ freeVariables :: Term -> [Integer ]
237
+ freeVariables = freeVariables' 0
238
+ where
239
+ freeVariables' :: Integer -> Term -> [Integer ]
240
+ -- freeVariables' cnt (TermTuple left right) = freeVariables' cnt left ++ freeVariables' cnt right
241
+ freeVariables' cnt (TermApply termLeft termRight) = freeVariables' cnt termLeft ++ freeVariables' cnt termRight
242
+ -- freeVariables' cnt (TermLetSingle termEq termIn) = freeVariables' cnt termEq ++ freeVariables' (cnt + 1) termIn --TODO: verify
243
+ -- freeVariables' cnt (TermLetSugarSingle termEq termIn) = freeVariables' cnt termEq ++ freeVariables' (cnt + 1) termIn --TODO: verify
244
+ -- freeVariables' cnt (TermLetMultiple termEq termIn) = freeVariables' cnt termEq ++ freeVariables' (cnt + 2) termIn
245
+ -- freeVariables' cnt (TermLetSugarMultiple termEq termIn) = freeVariables' cnt termEq ++ freeVariables' (cnt + 2) termIn
246
+ freeVariables' cnt (TermLambda _ lambdaTerm) = freeVariables' (cnt + 1 ) lambdaTerm
247
+ freeVariables' cnt (TermBoundVariable i) = [i - cnt | i >= cnt]
248
+ freeVariables' _ _ = []
249
+
250
+ extractFunctionNames :: Term -> [String ]
251
+ -- extractFunctionNames (TermTuple left right) = extractFunctionNames left ++ extractFunctionNames right
252
+ extractFunctionNames (TermApply termLeft termRight) = extractFunctionNames termLeft ++ extractFunctionNames termRight
253
+ -- extractFunctionNames (TermLetSingle termEq termIn) = extractFunctionNames termEq ++ extractFunctionNames termIn
254
+ -- extractFunctionNames (TermLetSugarSingle termEq termIn) = extractFunctionNames termEq ++ extractFunctionNames termIn
255
+ -- extractFunctionNames (TermLetMultiple termEq termIn) = extractFunctionNames termEq ++ extractFunctionNames termIn
256
+ -- extractFunctionNames (TermLetSugarMultiple termEq termIn) = extractFunctionNames termEq ++ extractFunctionNames termIn
257
+ extractFunctionNames (TermLambda _ lambdaTerm) = extractFunctionNames lambdaTerm
258
+ extractFunctionNames (TermFreeVariable fun) = [fun]
259
+ extractFunctionNames _ = []
0 commit comments