@@ -250,8 +250,7 @@ static mlir::rlc::ActionFunction deduceActionType(mlir::rlc::ActionFunction fun)
250250 generatedFunctions,
251251 fun.getUnmangledName (),
252252 fun.getInfo ());
253- if (fun->hasAttr (" emit_classes" ))
254- newAction->setAttr (" emit_classes" , rewriter.getUnitAttr ());
253+ newAction->setDiscardableAttrs (fun->getDiscardableAttrDictionary ());
255254 newAction.getBody ().takeBody (fun.getBody ());
256255 newAction.getPrecondition ().takeBody (fun.getPrecondition ());
257256 fun.getResult ().replaceAllUsesWith (newAction.getResult ());
@@ -556,8 +555,6 @@ static mlir::Type declareActionStatementType(
556555 name += snakeCaseToCamelCase (statement.getName ());
557556
558557 llvm::SmallVector<mlir::rlc::ClassFieldAttr, 4 > fields;
559- llvm::SmallVector<mlir::rlc::ClassFieldDeclarationAttr, 4 > fieldsAttrs;
560-
561558 for (auto [name, result] :
562559 llvm::zip (statement.getDeclaredNames (), statement.getResults ()))
563560 {
@@ -568,19 +565,21 @@ static mlir::Type declareActionStatementType(
568565 type = casted.getUnderlying ();
569566 auto field = mlir::rlc::ClassFieldAttr::get (name, type);
570567 fields.push_back (field);
571- fieldsAttrs.push_back (mlir::rlc::ClassFieldDeclarationAttr::get (field));
572568 }
573569
574570 auto type = mlir::rlc::ClassType::getNewIdentified (
575571 function.getContext (), name, fields, {});
576572
577573 auto built = builder.getRewriter ().create <mlir::rlc::ClassDeclaration>(
578- statement.getLoc (),
579- type,
580- name,
581- fieldsAttrs,
582- llvm::ArrayRef<mlir::Type>({}),
583- nullptr );
574+ statement.getLoc (), type, name, llvm::ArrayRef<mlir::Type>({}), nullptr );
575+ builder.getRewriter ().createBlock (&built.getBody ());
576+
577+ for (auto field : fields)
578+ {
579+ builder.getRewriter ().create <mlir::rlc::ClassFieldDeclaration>(
580+ built.getLoc (), mlir::rlc::ClassFieldDeclarationAttr::get (field));
581+ }
582+ builder.getRewriter ().setInsertionPointAfter (built);
584583 mlir::rlc::markSynthetic (built);
585584 auto applyFunction =
586585 defineApplyFunction (function, builder, action, statement, type);
@@ -1472,35 +1471,34 @@ mlir::LogicalResult mlir::rlc::ReturnStatement::typeCheck(
14721471 assert (yield);
14731472 const bool returnsValue = yield->getNumOperands () != 0 ;
14741473
1475- auto newOne = rewriter.create <mlir::rlc::ReturnStatement>(
1476- getLoc (),
1477- returnsValue ? yield->getOpOperand (0 ).get ().getType ()
1478- : mlir::rlc::VoidType::get (getContext ()));
1479- newOne.getBody ().takeBody (getBody ());
1480- rewriter.eraseOp (*this );
1474+ auto returnedType = returnsValue ? yield->getOpOperand (0 ).get ().getType ()
1475+ : mlir::rlc::VoidType::get (getContext ());
14811476
1482- if (newOne->getBlock ()->getTerminator () != newOne)
1477+ setResult (returnedType);
1478+
1479+ if (getOperation ()->getBlock ()->getTerminator () != getOperation ())
14831480 {
14841481 return mlir::rlc::logError (
1485- newOne ,
1482+ * this ,
14861483 " Return statement should be the last statement of its code block." );
14871484 }
14881485
1489- if (auto parentFunction = newOne->getParentOfType <mlir::rlc::FunctionOp>())
1486+ if (auto parentFunction =
1487+ getOperation ()->getParentOfType <mlir::rlc::FunctionOp>())
14901488 {
14911489 mlir::Type returnType =
14921490 (parentFunction.getType ().getNumResults () != 0
14931491 ? parentFunction.getResultTypes ()[0 ]
14941492 : mlir::rlc::VoidType::get (getContext ()));
14951493
1496- if (not isReturnTypeCompatible (newOne. getResult (), returnType))
1494+ if (not isReturnTypeCompatible (getResult (), returnType))
14971495 {
14981496 auto _ = mlir::rlc::logError (
1499- newOne ,
1497+ * this ,
15001498 " Return statement returns values incompatible with the function "
15011499 " signature" );
15021500 _ = mlir::rlc::logRemark (
1503- newOne , " Return value type is " + prettyType (newOne. getResult ()));
1501+ * this , " Return value type is " + prettyType (getResult ()));
15041502
15051503 return mlir::rlc::logRemark (
15061504 parentFunction,
@@ -2430,6 +2428,16 @@ mlir::LogicalResult mlir::rlc::FromByteArrayOp::typeCheck(
24302428 " are supported" );
24312429}
24322430
2431+ llvm::SmallVector<mlir::rlc::ShugarizedTypeAttr, 2 >
2432+ mlir::rlc::ClassFieldDeclaration::getShugarizedTypes ()
2433+ {
2434+ llvm::SmallVector<mlir::rlc::ShugarizedTypeAttr, 2 > toReturn;
2435+
2436+ if (getDeclaration ().getShugarizedType () != nullptr )
2437+ toReturn.push_back (getDeclaration ().getShugarizedType ());
2438+ return toReturn;
2439+ }
2440+
24332441llvm::SmallVector<mlir::rlc::ShugarizedTypeAttr, 2 >
24342442mlir::rlc::ConstantGlobalOp::getShugarizedTypes ()
24352443{
@@ -2498,17 +2506,6 @@ mlir::rlc::IsOp::getShugarizedTypes()
24982506 return { *getShugarizedType () };
24992507}
25002508
2501- llvm::SmallVector<mlir::rlc::ShugarizedTypeAttr, 2 >
2502- mlir::rlc::ClassDeclaration::getShugarizedTypes ()
2503- {
2504- llvm::SmallVector<mlir::rlc::ShugarizedTypeAttr, 2 > toReturn;
2505-
2506- for (auto parameter : getMemberFields ())
2507- if (parameter.getShugarizedType () != nullptr )
2508- toReturn.push_back (parameter.getShugarizedType ());
2509- return toReturn;
2510- }
2511-
25122509llvm::SmallVector<mlir::rlc::ShugarizedTypeAttr, 2 >
25132510mlir::rlc::TypeAliasOp::getShugarizedTypes ()
25142511{
@@ -3440,3 +3437,19 @@ size_t mlir::rlc::EnumDeclarationOp::countFields()
34403437 auto range = getBody ().front ().getOps <mlir::rlc::EnumFieldDeclarationOp>();
34413438 return std::distance (range.begin (), range.end ());
34423439}
3440+
3441+ llvm::SmallVector<mlir::rlc::ClassFieldDeclarationAttr, 4 >
3442+ mlir::rlc::ClassDeclaration::getMemberFields ()
3443+ {
3444+ llvm::SmallVector<mlir::rlc::ClassFieldDeclarationAttr, 4 > outs;
3445+ for (auto member : getBody ().getOps <mlir::rlc::ClassFieldDeclaration>())
3446+ outs.push_back (member.getDeclaration ());
3447+
3448+ return outs;
3449+ }
3450+
3451+ mlir::rlc::ClassFieldDeclarationAttr
3452+ mlir::rlc::ClassDeclaration::getMemberField (size_t i)
3453+ {
3454+ return getMemberFields ()[i];
3455+ }
0 commit comments