@@ -231,19 +231,16 @@ namespace mlir::rlc
231231 mlir::Type returnType,
232232 mlir::rlc::StreamWriter& writer,
233233 bool isMac,
234- bool isWindows)
234+ bool isWindows,
235+ llvm::SmallVector<std::string>& declaredFunNames)
235236 {
236- writer.write (" [DllImport(\" Lib." );
237- if (isMac)
238- writer.write (" dylib" );
239- else if (isWindows)
240- writer.write (" dll" );
241- else
242- writer.write (" so" );
243- writer.writenl (" \" )]" );
244- writer.write (" public static extern void " , mangledName);
237+ writer.write (" public delegate void Delegate" , mangledName);
245238 writeFunctionArgs (types, args, returnType, writer, true );
246239 writer.writenl (" ;" );
240+ writer.writenl (
241+ " public static Delegate" , mangledName, " " , mangledName, " ;" );
242+
243+ declaredFunNames.push_back (mangledName.str ());
247244 }
248245
249246 static void emitReturnVariable (mlir::Type returnType, StreamWriter& writer)
@@ -446,10 +443,14 @@ namespace mlir::rlc
446443 private:
447444 bool isMac;
448445 bool isWindows;
446+ llvm::SmallVector<std::string>& declaredFunNames;
449447
450448 public:
451- CSharpFunctionDeclarationMatcher (bool isMac, bool isWindows)
452- : isMac(isMac), isWindows(isWindows)
449+ CSharpFunctionDeclarationMatcher (
450+ bool isMac,
451+ bool isWindows,
452+ llvm::SmallVector<std::string>& declaredFunNames)
453+ : isMac(isMac), isWindows(isWindows), declaredFunNames(declaredFunNames)
453454 {
454455 }
455456 void apply (mlir::rlc::FunctionOp op, mlir::rlc::StreamWriter& writer)
@@ -464,7 +465,8 @@ namespace mlir::rlc
464465 getResultType (op.getFunctionType ()),
465466 writer,
466467 isMac,
467- isWindows);
468+ isWindows,
469+ declaredFunNames);
468470
469471 if (not op.getPrecondition ().empty ())
470472 declareFunction (
@@ -474,7 +476,8 @@ namespace mlir::rlc
474476 mlir::rlc::BoolType::get (op.getContext ()),
475477 writer,
476478 isMac,
477- isWindows);
479+ isWindows,
480+ declaredFunNames);
478481 }
479482 };
480483
@@ -484,11 +487,18 @@ namespace mlir::rlc
484487 mlir::rlc::ModuleBuilder& builder;
485488 bool isMac;
486489 bool isWindows;
490+ llvm::SmallVector<std::string>& declaredFunNames;
487491
488492 public:
489493 CSharpActionDeclarationMatcher (
490- mlir::rlc::ModuleBuilder& builder, bool isMac, bool isWindows)
491- : builder(builder), isMac(isMac), isWindows(isWindows)
494+ mlir::rlc::ModuleBuilder& builder,
495+ bool isMac,
496+ bool isWindows,
497+ llvm::SmallVector<std::string>& declaredFunNames)
498+ : builder(builder),
499+ isMac (isMac),
500+ isWindows(isWindows),
501+ declaredFunNames(declaredFunNames)
492502 {
493503 }
494504 void apply (mlir::rlc::ActionFunction op, mlir::rlc::StreamWriter& writer)
@@ -504,7 +514,8 @@ namespace mlir::rlc
504514 getResultType (op.getFunctionType ()),
505515 writer,
506516 isMac,
507- isWindows);
517+ isWindows,
518+ declaredFunNames);
508519 if (not op.getPrecondition ().empty ())
509520 {
510521 declareFunction (
@@ -514,7 +525,8 @@ namespace mlir::rlc
514525 mlir::rlc::BoolType::get (op.getContext ()),
515526 writer,
516527 isMac,
517- isWindows);
528+ isWindows,
529+ declaredFunNames);
518530 }
519531
520532 for (auto value : op.getActions ())
@@ -537,7 +549,8 @@ namespace mlir::rlc
537549 getResultType (fType ),
538550 writer,
539551 isMac,
540- isWindows);
552+ isWindows,
553+ declaredFunNames);
541554
542555 auto canDoType = mlir::FunctionType::get (
543556 fType .getContext (),
@@ -552,7 +565,8 @@ namespace mlir::rlc
552565 getResultType (canDoType),
553566 writer,
554567 isMac,
555- isWindows);
568+ isWindows,
569+ declaredFunNames);
556570 }
557571
558572 auto canFType = mlir::FunctionType::get (
@@ -567,7 +581,8 @@ namespace mlir::rlc
567581 getResultType (canFType),
568582 writer,
569583 isMac,
570- isWindows);
584+ isWindows,
585+ declaredFunNames);
571586 }
572587 };
573588
@@ -1036,6 +1051,115 @@ namespace mlir::rlc
10361051 writer.writenl (" }" ).endLine ();
10371052 }
10381053
1054+ static void emitSetTearDown (
1055+ llvm::SmallVector<std::string>& declaredFunNames, StreamWriter& writer)
1056+ {
1057+ writer.writenl (" private static IntPtr _lib;" );
1058+ auto _ = writer.indent ();
1059+ writer.writenl (" public static void setup(string libName) {" );
1060+ {
1061+ writer.write (" _lib = RLCNative.LoadLibrary(libName);" );
1062+ writer.write (
1063+ " if (_lib == IntPtr.Zero) throw new Exception(\" Could not find "
1064+ " library \" + libName );" );
1065+ auto _ = writer.indent ();
1066+ for (auto & exposedSymbol : declaredFunNames)
1067+ {
1068+ writer.writenl (
1069+ " IntPtr " ,
1070+ exposedSymbol,
1071+ " _ptr = GetProcAddress(_lib, \" " ,
1072+ exposedSymbol,
1073+ " \" );" );
1074+ writer.writenl (
1075+ " if (" ,
1076+ exposedSymbol,
1077+ " _ptr == IntPtr.Zero) throw new Exception(\" Could not find symbol " ,
1078+ exposedSymbol,
1079+ " \" );" );
1080+ writer.writenl (
1081+ exposedSymbol,
1082+ " = Marshal.GetDelegateForFunctionPointer<Delegate" ,
1083+ exposedSymbol,
1084+ " >(" ,
1085+ exposedSymbol,
1086+ " _ptr);" );
1087+ }
1088+ }
1089+ writer.writenl (" }" ).endLine ();
1090+
1091+ writer.writenl (" public static void teardown() {" );
1092+ {
1093+ writer.writenl (" if (_lib == IntPtr.Zero) return;" );
1094+ auto _ = writer.indent ();
1095+ for (auto & exposedSymbol : declaredFunNames)
1096+ {
1097+ writer.writenl (exposedSymbol, " = null;" );
1098+ }
1099+
1100+ writer.write (" RLCNative.FreeLibrary(_lib);" );
1101+ writer.writenl (" _lib = IntPtr.Zero;" );
1102+ }
1103+ writer.writenl (" }" ).endLine ();
1104+ }
1105+
1106+ static void emitDLLImporters (bool isMac, bool isWindows, StreamWriter& writer)
1107+ {
1108+ if (isMac)
1109+ {
1110+ writer.writenl (" const string LIBDL = \" libSystem.B.dylib\" ;" );
1111+ writer.writenl (" const int RTLD_NOW = 2;" );
1112+ writer.writenl (" [DllImport(LIBDL)] static extern IntPtr dlopen (string "
1113+ " path, int flags);" );
1114+ writer.writenl (
1115+ " [DllImport(LIBDL)] static extern int dlclose(IntPtr handle);" );
1116+ writer.writenl (" [DllImport(LIBDL)] static extern IntPtr dlsym (IntPtr "
1117+ " handle, string name);" );
1118+ writer.writenl (
1119+ " static IntPtr LoadLibrary (string p) => dlopen (p, RTLD_NOW);" );
1120+ writer.writenl (
1121+ " static bool FreeLibrary (IntPtr h) { dlclose(h); return true; }" );
1122+ writer.writenl (
1123+ " static IntPtr GetProcAddress(IntPtr h,string n)=>dlsym(h,n);" );
1124+ }
1125+ else if (isWindows)
1126+ {
1127+ writer.writenl (" const string KERNEL = \" kernel32\"\n " );
1128+ writer.writenl (" [DllImport(KERNEL, SetLastError = true)] static extern "
1129+ " IntPtr LoadLibrary(string path) " );
1130+ writer.writenl (" [DllImport(KERNEL, SetLastError = true)] static extern "
1131+ " bool FreeLibrary(IntPtr hModule);" );
1132+ writer.writenl (" [DllImport(KERNEL)] static extern "
1133+ " IntPtr GetProcAddress(IntPtr h, string name);" );
1134+ }
1135+ else
1136+ {
1137+ writer.writenl (" const string LIBDL = \" libdl.so.2\" ;" );
1138+ writer.writenl (" const int RTLD_NOW = 2;" );
1139+ writer.writenl (" [DllImport(LIBDL)] static extern IntPtr dlopen (string "
1140+ " path, int flags);" );
1141+ writer.writenl (
1142+ " [DllImport(LIBDL)] static extern int dlclose(IntPtr handle);" );
1143+ writer.writenl (" [DllImport(LIBDL)] static extern IntPtr dlsym (IntPtr "
1144+ " handle, string name);" );
1145+ writer.writenl (
1146+ " [DllImport(LIBDL, CharSet = CharSet.Ansi, ExactSpelling = true)]" );
1147+ writer.writenl (" static extern IntPtr dlerror();" );
1148+ writer.writenl (
1149+ " static IntPtr LoadLibrary (string p) => dlopen (p, RTLD_NOW);" );
1150+ writer.writenl (" static string DlLastError()" );
1151+ writer.writenl (" {" );
1152+ writer.writenl (" IntPtr p = dlerror();" );
1153+ writer.writenl (
1154+ " return p != IntPtr.Zero ? Marshal.PtrToStringAnsi(p) : null;" );
1155+ writer.writenl (" }" );
1156+ writer.writenl (
1157+ " static bool FreeLibrary (IntPtr h) { dlclose(h); return true; }" );
1158+ writer.writenl (
1159+ " static IntPtr GetProcAddress(IntPtr h,string n)=>dlsym(h,n);" );
1160+ }
1161+ }
1162+
10391163#define GEN_PASS_DEF_PRINTCSHARPPASS
10401164#include " rlc/dialect/Passes.inc"
10411165
@@ -1049,14 +1173,20 @@ namespace mlir::rlc
10491173 MemberFunctionsTable table (getOperation ());
10501174 mlir::rlc::ModuleBuilder builder (getOperation ());
10511175
1176+ llvm::SmallVector<std::string> declaredFunNames;
1177+
10521178 emitPrelude (matcher.getWriter ());
10531179 matcher.addTypeSerializer ();
10541180 registerTypeConversion (matcher.getWriter ().getTypeSerializer ());
10551181 registerTypeConversionRaw (matcher.getWriter ().getTypeSerializer (1 ));
1056- matcher.getWriter ().writenl (" unsafe class RLCNative {" );
1057- matcher.add <CSharpFunctionDeclarationMatcher>(isMac, isWindows);
1058- matcher.add <CSharpActionDeclarationMatcher>(builder, isMac, isWindows);
1182+ matcher.getWriter ().writenl (" public unsafe class RLCNative {" );
1183+ emitDLLImporters (isMac, isWindows, matcher.getWriter ());
1184+ matcher.add <CSharpFunctionDeclarationMatcher>(
1185+ isMac, isWindows, declaredFunNames);
1186+ matcher.add <CSharpActionDeclarationMatcher>(
1187+ builder, isMac, isWindows, declaredFunNames);
10591188 matcher.apply (getOperation ());
1189+ emitSetTearDown (declaredFunNames, matcher.getWriter ());
10601190 matcher.getWriter ().writenl (" }" ).endLine ();
10611191
10621192 matcher.clearMatchers ();
0 commit comments