Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Strata/Languages/Python/PySpecPipeline.lean
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ private def specDefaultToExpr : Python.Specs.SpecDefault → Python.expr SourceR
| .none => .Constant default (.ConNone default) default

/-- Convert a pyspec Arg to a PythonFunctionDecl arg tuple. -/
private def specArgToFuncDeclArg (arg : Python.Specs.Arg)
: String × String × Option (Python.expr SourceRange) :=
(arg.name, "Any", arg.default.map specDefaultToExpr)
private def specArgToFuncDeclArg (arg : Python.Specs.Arg): Python.PyArgInfo :=
{name:= arg.name, md:= default, tys:= ["Any"], default:= arg.default.map specDefaultToExpr}

/-- Build a PythonFunctionDecl from a PySpec FunctionDecl or class method,
expanding `**kwargs` TypedDict fields into individual parameters. -/
Expand Down
129 changes: 96 additions & 33 deletions Strata/Languages/Python/PythonToLaurel.lean
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,17 @@ inductive UnmodeledFunctionBehavior where
| havocInputsAndOutputs
deriving Inhabited

structure PyArgInfo where
name : String
md : MetaData
tys : List String
default : Option (Python.expr SourceRange)
deriving Repr

structure PythonFunctionDecl where
name : String
--args include name, type, default value
args : List (String × String × Option (Python.expr SourceRange))
args : List PyArgInfo
hasKwargs: Bool
ret : Option String
deriving Repr, Inhabited
Expand Down Expand Up @@ -181,10 +188,11 @@ partial def getSubscriptBaseName (e : Python.expr SourceRange) : String :=
| .Subscript _ val _ _ => getSubscriptBaseName val
| _ => pyExprToString e

def PyLauType.None := "None"
def PyLauType.Int := "int"
def PyLauType.Bool := "bool"
def PyLauType.Str := "str"
def PyLauType.Datetime := "Datetime"
def PyLauType.Datetime := "datetime"
def PyLauType.DictStrAny := "DictStrAny"
def PyLauType.ListStr := "ListStr"
def PyLauType.Package := "Package"
Expand Down Expand Up @@ -574,6 +582,7 @@ partial def inferExprType (ctx : TranslationContext) (e: Python.expr SourceRange
| .Constant _ (.ConFalse _) _
| .BoolOp _ _ _
| .Compare _ _ _ _=> return PyLauType.Bool
| .Constant _ (.ConNone _) _ => return PyLauType.None
-- Variable references
| .Name _ n _ =>
match ctx.variableTypes.find? (λ v => v.fst == n.val) with
Expand Down Expand Up @@ -681,7 +690,7 @@ partial def removePosargsFromKwargs (kwords : List (Python.keyword SourceRange))
kwords.filter (λ kw => match kw with
| .mk_keyword _ name _ =>
match name.val with
| some n => n.val ∉ funcDecl.args.unzip.fst
| some n => n.val ∉ funcDecl.args.map (·.name)
| none => true)

partial def combinePositionalAndKeywordArgs
Expand All @@ -703,18 +712,18 @@ partial def combinePositionalAndKeywordArgs
let kwords := pyKwordsToHashMap kwords
let unprovidedPosArgs := funcDecl.args.drop posArgs.length
--every unprovided positional args must have a default value in the function signature or be provided in the kwargs
let missingArgs := unprovidedPosArgs.filter fun (name, _, d) =>
!(name ∈ kwords.keys) && d.isNone
let missingArgs := unprovidedPosArgs.filter fun arg =>
!(arg.name ∈ kwords.keys) && arg.default.isNone
if missingArgs.length > 0 then
let missingNames := missingArgs.map (·.1)
throwUserError callRange s!"'{name}' called with missing required arguments: {missingNames}"
let filledPosArgs ←
unprovidedPosArgs.mapM (λ (argName, _, d) =>
match kwords.get? argName with
unprovidedPosArgs.mapM (λ arg =>
match kwords.get? arg.name with
| some expr => return expr
| none => match d with
| none => match arg.default with
| some val => return val
| _ => throw (.typeError s!"'{name}' missing required argument '{argName}'"))
| _ => throw (.typeError s!"'{name}' missing required argument '{arg.name}'"))
let posArgs := posArgs ++ filledPosArgs
return (posArgs, kwordArgs, funcDecl.hasKwargs)
| _ => return (posArgs, kwords, false)
Expand Down Expand Up @@ -785,8 +794,9 @@ def withException (ctx : TranslationContext) (funcname: String) : Bool :=
| some sig => sig.outputs.length > 0 && sig.outputs.getLast! == "Error"
| _ => false

def maybeExceptVar := mkStmtExprMd (.Identifier "maybe_except")
def nullcall_var := mkStmtExprMd (.Identifier "nullcall_ret")
def freeVar (name: String) := mkStmtExprMd (.Identifier name)
def maybeExceptVar := freeVar "maybe_except"
def nullcall_var := freeVar "nullcall_ret"

partial def translateAssign (ctx : TranslationContext)
(lhs: Python.expr SourceRange)
Expand Down Expand Up @@ -1154,45 +1164,59 @@ partial def getNestedSubscripts (expr: Python.expr SourceRange) : List ( Python
| .Subscript _ val slice _ => [val] ++ (getNestedSubscripts slice)
| _ => [expr]

partial def argumentTypeToString (arg: Python.expr SourceRange) : Except TranslationError String :=
def getUnionTypes (arg: Python.expr SourceRange) : List (Python.expr SourceRange) :=
match arg with
| .Name _ n _ => return n.val
| .BinOp _ left _ right => getUnionTypes left ++ getUnionTypes right
| _ => [arg]

partial def argumentTypeToString (arg: Python.expr SourceRange) : Except TranslationError (List String) :=
match arg with
| .Name _ n _ => return [n.val]
| .Subscript _ _ _ _ =>
let subscript_list:= getNestedSubscripts arg
let subscript_head := subscript_list[0]!
let slice_head := subscript_list[1]!
let v_name := pyExprToString subscript_head
match v_name with
| "Optional" => return "NoneOr" ++ pyExprToString slice_head
| _ => return v_name
| .Constant _ _ _ => return "None"
| .Attribute _ _ _ _ => return pyExprToString arg
| "Optional" => return [pyExprToString slice_head, "None"]
| "Union" => match slice_head with
| .Tuple _ tys _ => return (← tys.val.toList.mapM argumentTypeToString).flatten
| _ => throw (.internalError s!"Unhandled Expr: {repr arg}")
| _ => return [v_name]
| .Constant _ _ _ => return ["None"]
| .Attribute _ _ _ _ => return [pyExprToString arg]
| .BinOp _ _ _ _ => return (← (getUnionTypes arg).mapM argumentTypeToString).flatten
| _ => throw (.internalError s!"Unhandled Expr: {repr arg}")

--The return is a List (inputname, type, default value) and a bool indicating if the function has Kwargs input
def unpackPyArguments (args: Python.arguments SourceRange)
: Except TranslationError ((List (String × String × Option (Python.expr SourceRange))) × Bool):= do
def unpackPyArguments (ctx : TranslationContext) (args: Python.arguments SourceRange)
: Except TranslationError ((List PyArgInfo) × Bool):= do
match args with
| .mk_arguments _ _ args _ _ _ kwargs defaults =>
let argscount := args.val.size
let defaultscount := defaults.val.size;
let listdefaults := (List.range (argscount - defaultscount)).map (λ _ => none)
++ defaults.val.toList.map (λ x => some x)
let argsinfo := args.val.toList.zip listdefaults
let argtypes ←
let argtypes : List PyArgInfo
argsinfo.mapM (λ a: Python.arg SourceRange × Option (Python.expr SourceRange) =>
match a.fst with
| .mk_arg _ name oty _ =>
| .mk_arg sr name oty _ =>
let md := sourceRangeToMetaData ctx.filePath sr
match oty.val with
| .some ty => return (name.val, ← argumentTypeToString ty, a.snd)
| _ => return (name.val, PyLauType.Any, a.snd))
| .some ty =>
let defaultType := match a.snd.mapM (inferExprType ctx) with
| .ok (some ty) => [ty]
| _ => []
return {name:= name.val, md:=md, tys:=(← argumentTypeToString ty) ++ defaultType, default:= a.snd}
| _ => return {name:= name.val, md:= md, tys:=[PyLauType.Any], default:=a.snd})
return (argtypes, kwargs.val.isSome)

def pyFuncDefToPythonFunctionDecl (ctx : TranslationContext) (f : Python.stmt SourceRange) : Except TranslationError PythonFunctionDecl := do
match f with
| .FunctionDef _ name args _body _decorator_list returns _type_comment _ =>
let name := match ctx.currentClassName with | none => name.val | some classname => classname ++ "_" ++ name.val
let args_trans ← unpackPyArguments args
let args_trans ← unpackPyArguments ctx args
let args := if ctx.currentClassName.isSome then args_trans.fst.tail else args_trans.fst
let ret := if name.endsWith "@__init__" then some (name.dropEnd "@__init__".length).toString
else
Expand All @@ -1208,23 +1232,60 @@ def pyFuncDefToPythonFunctionDecl (ctx : TranslationContext) (f : Python.stmt S
}
| _ => throw (.internalError "Expected FunctionDef")

def getSingleTypeConstraint (var: String) (ty: String): Option StmtExprMd :=
match ty with
| "str" => mkStmtExprMd (.StaticCall "Any..isfrom_string" [freeVar var])
| "int" => mkStmtExprMd (.StaticCall "Any..isfrom_int" [freeVar var])
| "bool" => mkStmtExprMd (.StaticCall "Any..isfrom_bool" [freeVar var])
| "datetime" => mkStmtExprMd (.StaticCall "Any..isfrom_datetime" [freeVar var])
| "None" => mkStmtExprMd (.StaticCall "Any..isfrom_none" [freeVar var])
| _ => if ty.startsWith "Dict" then mkStmtExprMd (.StaticCall "Any..isfrom_Dict" [freeVar var])
else if ty.startsWith "List" then mkStmtExprMd (.StaticCall "Any..isfrom_ListAny" [freeVar var])
else none

def creatBoolOrExpr (exprs: List StmtExprMd) : StmtExprMd :=
match exprs with
| [] => mkStmtExprMd (.LiteralBool true)
| [expr] => expr
| expr::exprs => mkStmtExprMd (.StaticCall "Bool.Or" [expr, creatBoolOrExpr exprs])

def getUnionTypeConstraint (var: String) (md: MetaData) (tys: List String) (funcname: String): Option StmtExprMd :=
let type_constraints := tys.filterMap (getSingleTypeConstraint var)
if type_constraints.isEmpty then none else
let md: MetaData := md.withPropertySummary $ "(" ++ funcname ++ " requires) Type constraint of " ++ var
some {creatBoolOrExpr type_constraints with md:=md}

def getUnionTypeAssertions (var: String) (md: MetaData) (tys: List String) (funcname: String): Option StmtExprMd :=
match getUnionTypeConstraint var md tys funcname with
| some constraint =>
let md: MetaData := md.withPropertySummary $ "(" ++ funcname ++ " assert) Type constraint of " ++ var
mkStmtExprMdWithLoc (.Assert constraint) md
| _ => none

def getInputTypePreconditions (funcDecl : PythonFunctionDecl): List StmtExprMd :=
funcDecl.args.filterMap (λ arg => getUnionTypeConstraint arg.name arg.md arg.tys funcDecl.name)

def getInputTypecheckAssertions (funcDecl : PythonFunctionDecl): List StmtExprMd :=
funcDecl.args.filterMap (λ arg => getUnionTypeAssertions arg.name arg.md arg.tys funcDecl.name)

/-- Translate Python function to Laurel Procedure -/
def translateFunction (ctx : TranslationContext) (sourceRange: SourceRange) (funcDecl : PythonFunctionDecl) (body: List (Python.stmt SourceRange))
: Except TranslationError (Laurel.Procedure × TranslationContext) := do

-- Translate parameters
let mut inputs : List Parameter := []

inputs := funcDecl.args.map (fun (name, ty, _) =>
if ty ∈ ctx.compositeTypeNames then
{ name := name, type := mkHighTypeMd (.UserDefined ty) }
inputs := funcDecl.args.map (fun arg =>
if arg.tys.length == 1 && arg.tys[0]! ∈ ctx.compositeTypeNames then
{ name := arg.name, type := mkHighTypeMd (.UserDefined {text:= arg.tys[0]!}) }
else
{ name := name, type := AnyTy})
{ name := arg.name, type := AnyTy})
if funcDecl.hasKwargs then
let paramType ← translateType ctx PyLauType.DictStrAny
inputs:= inputs ++ [{ name := "kwargs", type := paramType }]

-- Translate return type
let typeConstraintAssertions := getInputTypecheckAssertions funcDecl
let typeConstraintPreconditions := getInputTypePreconditions funcDecl


-- Declare an output parameter when the function has a return type annotation.
Expand All @@ -1235,19 +1296,21 @@ def translateFunction (ctx : TranslationContext) (sourceRange: SourceRange) (fun
| some _ => [{ name := "LaurelResult", type := AnyTy }]

-- Translate function body
let inputTypes := funcDecl.args.map (λ (name, type, _) => (name, type))
let inputTypes := funcDecl.args.map (λ arg =>
match arg.tys with | [ty] => (arg.name, ty) | _ => (arg.name, PyLauType.Any))
let ctx := {ctx with variableTypes:= ("nullcall_ret", PyLauType.Any)::inputTypes}
let (newctx, bodyStmts) ← translateStmtList ctx body
let bodyStmts := prependExceptHandlingHelper bodyStmts
let bodyStmts := (mkStmtExprMd (.LocalVariable "nullcall_ret" AnyTy (some AnyNone))) :: bodyStmts
let bodyStmts := typeConstraintAssertions ++ bodyStmts
let bodyBlock := mkStmtExprMd (StmtExpr.Block bodyStmts none)

-- Create procedure with transparent body (no contracts for now)
let proc : Procedure := {
name := funcDecl.name
inputs := inputs
outputs := outputs
preconditions := []
preconditions := typeConstraintPreconditions
determinism := .deterministic none -- TODO: need to set reads
decreases := none
body := Body.Transparent bodyBlock
Expand Down Expand Up @@ -1293,7 +1356,7 @@ def preludeSignatureToPythonFunctionDecl (prelude : Core.Program) : List PythonF
let noneexpr : Python.expr SourceRange := .Constant default (.ConNone default) default
some {
name:= proc.header.name.name
args:= (inputnames.zip inputTypes).map (λ(n,t) => (n,t,noneexpr))
args:= (inputnames.zip inputTypes).map (λ(n,t) => {name:= n, md:= defaultMetadata, tys:= [t], default:= noneexpr})
hasKwargs := false
ret := if outputtypes.length == 0 then none else outputtypes[0]!
}
Expand Down Expand Up @@ -1542,7 +1605,7 @@ def PreludeInfo.ofLaurelProgram (prog : Laurel.Program) : PreludeInfo where
else
let noDefault : Option (Python.expr SourceRange) := none
let args := p.inputs.map fun param =>
(param.name.text, getHighTypeName param.type.val, noDefault)
{name:= param.name.text, md:= default, tys:= [getHighTypeName param.type.val], default:= noDefault}
let ret := p.outputs.head?.map fun param => getHighTypeName param.type.val
some { name := p.name.text, args := args, hasKwargs := false, ret := ret }
functions :=
Expand Down
3 changes: 2 additions & 1 deletion StrataMain.lean
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,8 @@ def pyAnalyzeLaurelCommand : Command where
("", "")
| none => ("", "")
let outcomeStr := vcResult.formatOutcome
s := s ++ s!"{locationPrefix}{vcResult.obligation.label}: \
let vcLabel := vcResult.obligation.metadata.getPropertySummary.getD vcResult.obligation.label
s := s ++ s!"{locationPrefix}{vcLabel}: \
{outcomeStr}{locationSuffix}\n"
IO.println s
-- Output in SARIF format if requested
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@

==== Verification Results ====
List_get_body_calls_List_get_0: ✅ pass
List_take_body_calls_List_take_0: ✅ pass
List_drop_body_calls_List_drop_0: ✅ pass
List_slice_body_calls_List_drop_0: ✅ pass
List_slice_body_calls_List_take_1: ✅ pass
List_set_body_calls_List_set_0: ✅ pass
DictStrAny_get_body_calls_DictStrAny_get_0: ✅ pass
Any_get_body_calls_DictStrAny_get_0: ✅ pass
Any_get_body_calls_List_get_1: ✅ pass
Any_get!_body_calls_DictStrAny_get_0: ✅ pass
Any_get!_body_calls_List_get_1: ✅ pass
Any_set_body_calls_List_set_0: ✅ pass
Any_set!_body_calls_List_set_0: ✅ pass
PFloorDiv_body_calls_Int.SafeDiv_0: ✅ pass
PFloorDiv_body_calls_Int.SafeDiv_1: ✅ pass
PFloorDiv_body_calls_Int.SafeDiv_2: ✅ pass
PFloorDiv_body_calls_Int.SafeDiv_3: ✅ pass
PAnd_body_calls_Any_to_bool_0: ✅ pass
POr_body_calls_Any_to_bool_0: ✅ pass
ret_type: ✅ pass (in prelude file)
ret_type: ✅ pass (in prelude file)
ret_pos: ✅ pass (in prelude file)
assert_name_is_foo: ✅ pass (in prelude file)
assert_opt_name_none_or_str: ✅ pass (in prelude file)
assert_opt_name_none_or_bar: ✅ pass (in prelude file)
ensures_maybe_except_none: ✅ pass (in prelude file)
(Mul assert) Type constraint of x: ✅ pass (at line 1, col 8)
(Mul assert) Type constraint of y: ✅ pass (at line 1, col 23)
(Sum assert) Type constraint of x: ✅ pass (at line 4, col 8)
(Sum assert) Type constraint of y: ✅ pass (at line 4, col 30)
ite_cond_calls_Any_to_bool_0: ✅ pass
(Mul requires) Type constraint of x: ✅ pass (at line 1, col 8)
(Mul requires) Type constraint of y: ✅ pass (at line 1, col 23)
(Sum requires) Type constraint of x: ✅ pass (at line 4, col 8)
(Sum requires) Type constraint of y: ✅ pass (at line 4, col 30)
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ assert_name_is_foo: ✅ pass (in prelude file)
assert_opt_name_none_or_str: ✅ pass (in prelude file)
assert_opt_name_none_or_bar: ✅ pass (in prelude file)
ensures_maybe_except_none: ✅ pass (in prelude file)
(my_f assert) Type constraint of s: ✅ pass (at line 5, col 9)
(Origin_test_helper_procedure_Requires)req_name_is_foo: ❓ unknown (in prelude file)
(Origin_test_helper_procedure_Requires)req_opt_name_none_or_str: ✅ pass (in prelude file)
(Origin_test_helper_procedure_Requires)req_opt_name_none_or_bar: ✅ pass (in prelude file)
(my_f requires) Type constraint of s: ✅ pass (at line 5, col 9)
ite_cond_calls_Any_to_bool_0: ✅ pass
2 changes: 1 addition & 1 deletion StrataTest/Languages/Python/run_py_analyze_sarif.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

BOTH_SKIP = {"test_foo_client_folder", "test_invalid_client_type", "test_unsupported_config"}
SKIP_TESTS = BOTH_SKIP | {"test_class_field_use", "test_list", "test_subscription", "test_with_statement", "test_class_field_init", "test_break_continue", "test_try_except", "test_try_except_scoping",
"test_augmented_assign"} # sarif pipeline uses PythonToCore which doesn't yet support AugAssign
"test_augmented_assign", "test_func_input_type_constraints"} # sarif pipeline uses PythonToCore which doesn't yet support AugAssign
SKIP_TESTS_LAUREL = BOTH_SKIP


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
def Mul(x: int | bool, y: int | bool = "abc") -> int:
return x * y

def Sum(x: Union[int , bool], y: Union[int , bool] = None) -> int:
if y == None:
return x
return x + y

a = Mul(1, True)
a = Sum(1, None)
Loading