@@ -231,19 +231,16 @@ namespace mlir::rlc
231231 mlir::Type returnType,
232232 mlir::rlc::StreamWriter& writer,
233233 bool isMac,
234- bool isWindows)
234+ bool isWindows,
235+ llvm::SmallVector<std::string>& declaredFunNames)
235236 {
236- writer.write (" [DllImport(\" Lib." );
237- if (isMac)
238- writer.write (" dylib" );
239- else if (isWindows)
240- writer.write (" dll" );
241- else
242- writer.write (" so" );
243- writer.writenl (" \" )]" );
244- writer.write (" public static extern void " , mangledName);
237+ writer.write (" public delegate void Delegate" , mangledName);
245238 writeFunctionArgs (types, args, returnType, writer, true );
246239 writer.writenl (" ;" );
240+ writer.writenl (
241+ " public static Delegate" , mangledName, " " , mangledName, " ;" );
242+
243+ declaredFunNames.push_back (mangledName.str ());
247244 }
248245
249246 static void emitReturnVariable (mlir::Type returnType, StreamWriter& writer)
@@ -446,10 +443,14 @@ namespace mlir::rlc
446443 private:
447444 bool isMac;
448445 bool isWindows;
446+ llvm::SmallVector<std::string>& declaredFunNames;
449447
450448 public:
451- CSharpFunctionDeclarationMatcher (bool isMac, bool isWindows)
452- : isMac(isMac), isWindows(isWindows)
449+ CSharpFunctionDeclarationMatcher (
450+ bool isMac,
451+ bool isWindows,
452+ llvm::SmallVector<std::string>& declaredFunNames)
453+ : isMac(isMac), isWindows(isWindows), declaredFunNames(declaredFunNames)
453454 {
454455 }
455456 void apply (mlir::rlc::FunctionOp op, mlir::rlc::StreamWriter& writer)
@@ -464,7 +465,8 @@ namespace mlir::rlc
464465 getResultType (op.getFunctionType ()),
465466 writer,
466467 isMac,
467- isWindows);
468+ isWindows,
469+ declaredFunNames);
468470
469471 if (not op.getPrecondition ().empty ())
470472 declareFunction (
@@ -474,7 +476,8 @@ namespace mlir::rlc
474476 mlir::rlc::BoolType::get (op.getContext ()),
475477 writer,
476478 isMac,
477- isWindows);
479+ isWindows,
480+ declaredFunNames);
478481 }
479482 };
480483
@@ -484,11 +487,18 @@ namespace mlir::rlc
484487 mlir::rlc::ModuleBuilder& builder;
485488 bool isMac;
486489 bool isWindows;
490+ llvm::SmallVector<std::string>& declaredFunNames;
487491
488492 public:
489493 CSharpActionDeclarationMatcher (
490- mlir::rlc::ModuleBuilder& builder, bool isMac, bool isWindows)
491- : builder(builder), isMac(isMac), isWindows(isWindows)
494+ mlir::rlc::ModuleBuilder& builder,
495+ bool isMac,
496+ bool isWindows,
497+ llvm::SmallVector<std::string>& declaredFunNames)
498+ : builder(builder),
499+ isMac (isMac),
500+ isWindows(isWindows),
501+ declaredFunNames(declaredFunNames)
492502 {
493503 }
494504 void apply (mlir::rlc::ActionFunction op, mlir::rlc::StreamWriter& writer)
@@ -504,7 +514,8 @@ namespace mlir::rlc
504514 getResultType (op.getFunctionType ()),
505515 writer,
506516 isMac,
507- isWindows);
517+ isWindows,
518+ declaredFunNames);
508519 if (not op.getPrecondition ().empty ())
509520 {
510521 declareFunction (
@@ -514,7 +525,8 @@ namespace mlir::rlc
514525 mlir::rlc::BoolType::get (op.getContext ()),
515526 writer,
516527 isMac,
517- isWindows);
528+ isWindows,
529+ declaredFunNames);
518530 }
519531
520532 for (auto value : op.getActions ())
@@ -537,7 +549,8 @@ namespace mlir::rlc
537549 getResultType (fType ),
538550 writer,
539551 isMac,
540- isWindows);
552+ isWindows,
553+ declaredFunNames);
541554
542555 auto canDoType = mlir::FunctionType::get (
543556 fType .getContext (),
@@ -552,7 +565,8 @@ namespace mlir::rlc
552565 getResultType (canDoType),
553566 writer,
554567 isMac,
555- isWindows);
568+ isWindows,
569+ declaredFunNames);
556570 }
557571
558572 auto canFType = mlir::FunctionType::get (
@@ -567,7 +581,8 @@ namespace mlir::rlc
567581 getResultType (canFType),
568582 writer,
569583 isMac,
570- isWindows);
584+ isWindows,
585+ declaredFunNames);
571586 }
572587 };
573588
@@ -1036,6 +1051,121 @@ namespace mlir::rlc
10361051 writer.writenl (" }" ).endLine ();
10371052 }
10381053
1054+ static void emitSetTearDown (
1055+ llvm::SmallVector<std::string>& declaredFunNames, StreamWriter& writer)
1056+ {
1057+ writer.writenl (" internal static string SharedLibExtension =>" );
1058+ writer.writenl (
1059+ " RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? \" .dll\" :" );
1060+ writer.writenl (
1061+ " RuntimeInformation.IsOSPlatform(OSPlatform.OSX) ? \" .dylib\" :" );
1062+ writer.writenl (" /* default to Linux */ \" .so\" ;" );
1063+ writer.writenl (" private static IntPtr _lib;" );
1064+ auto _ = writer.indent ();
1065+ writer.writenl (" public static void setup(string libName) {" );
1066+ {
1067+ writer.write (" _lib = RLCNative.LoadLibrary(libName);" );
1068+ writer.write (
1069+ " if (_lib == IntPtr.Zero) throw new Exception(\" Could not find "
1070+ " library \" + libName );" );
1071+ auto _ = writer.indent ();
1072+ for (auto & exposedSymbol : declaredFunNames)
1073+ {
1074+ writer.writenl (
1075+ " IntPtr " ,
1076+ exposedSymbol,
1077+ " _ptr = GetProcAddress(_lib, \" " ,
1078+ exposedSymbol,
1079+ " \" );" );
1080+ writer.writenl (
1081+ " if (" ,
1082+ exposedSymbol,
1083+ " _ptr == IntPtr.Zero) throw new Exception(\" Could not find symbol " ,
1084+ exposedSymbol,
1085+ " \" );" );
1086+ writer.writenl (
1087+ exposedSymbol,
1088+ " = Marshal.GetDelegateForFunctionPointer<Delegate" ,
1089+ exposedSymbol,
1090+ " >(" ,
1091+ exposedSymbol,
1092+ " _ptr);" );
1093+ }
1094+ }
1095+ writer.writenl (" }" ).endLine ();
1096+
1097+ writer.writenl (" public static void teardown() {" );
1098+ {
1099+ writer.writenl (" if (_lib == IntPtr.Zero) return;" );
1100+ auto _ = writer.indent ();
1101+ for (auto & exposedSymbol : declaredFunNames)
1102+ {
1103+ writer.writenl (exposedSymbol, " = null;" );
1104+ }
1105+
1106+ writer.write (" RLCNative.FreeLibrary(_lib);" );
1107+ writer.writenl (" _lib = IntPtr.Zero;" );
1108+ }
1109+ writer.writenl (" }" ).endLine ();
1110+ }
1111+
1112+ static void emitDLLImporters (bool isMac, bool isWindows, StreamWriter& writer)
1113+ {
1114+ if (isMac)
1115+ {
1116+ writer.writenl (" const string LIBDL = \" libSystem.B.dylib\" ;" );
1117+ writer.writenl (" const int RTLD_NOW = 2;" );
1118+ writer.writenl (" [DllImport(LIBDL)] static extern IntPtr dlopen (string "
1119+ " path, int flags);" );
1120+ writer.writenl (
1121+ " [DllImport(LIBDL)] static extern int dlclose(IntPtr handle);" );
1122+ writer.writenl (" [DllImport(LIBDL)] static extern IntPtr dlsym (IntPtr "
1123+ " handle, string name);" );
1124+ writer.writenl (
1125+ " static IntPtr LoadLibrary (string p) => dlopen (p, RTLD_NOW);" );
1126+ writer.writenl (
1127+ " static bool FreeLibrary (IntPtr h) { dlclose(h); return true; }" );
1128+ writer.writenl (
1129+ " static IntPtr GetProcAddress(IntPtr h,string n)=>dlsym(h,n);" );
1130+ }
1131+ else if (isWindows)
1132+ {
1133+ writer.writenl (" const string KERNEL = \" kernel32\"\n " );
1134+ writer.writenl (" [DllImport(KERNEL, SetLastError = true)] static extern "
1135+ " IntPtr LoadLibrary(string path) " );
1136+ writer.writenl (" [DllImport(KERNEL, SetLastError = true)] static extern "
1137+ " bool FreeLibrary(IntPtr hModule);" );
1138+ writer.writenl (" [DllImport(KERNEL)] static extern "
1139+ " IntPtr GetProcAddress(IntPtr h, string name);" );
1140+ }
1141+ else
1142+ {
1143+ writer.writenl (" const string LIBDL = \" libdl.so.2\" ;" );
1144+ writer.writenl (" const int RTLD_NOW = 2;" );
1145+ writer.writenl (" [DllImport(LIBDL)] static extern IntPtr dlopen (string "
1146+ " path, int flags);" );
1147+ writer.writenl (
1148+ " [DllImport(LIBDL)] static extern int dlclose(IntPtr handle);" );
1149+ writer.writenl (" [DllImport(LIBDL)] static extern IntPtr dlsym (IntPtr "
1150+ " handle, string name);" );
1151+ writer.writenl (
1152+ " [DllImport(LIBDL, CharSet = CharSet.Ansi, ExactSpelling = true)]" );
1153+ writer.writenl (" static extern IntPtr dlerror();" );
1154+ writer.writenl (
1155+ " static IntPtr LoadLibrary (string p) => dlopen (p, RTLD_NOW);" );
1156+ writer.writenl (" static string DlLastError()" );
1157+ writer.writenl (" {" );
1158+ writer.writenl (" IntPtr p = dlerror();" );
1159+ writer.writenl (
1160+ " return p != IntPtr.Zero ? Marshal.PtrToStringAnsi(p) : null;" );
1161+ writer.writenl (" }" );
1162+ writer.writenl (
1163+ " static bool FreeLibrary (IntPtr h) { dlclose(h); return true; }" );
1164+ writer.writenl (
1165+ " static IntPtr GetProcAddress(IntPtr h,string n)=>dlsym(h,n);" );
1166+ }
1167+ }
1168+
10391169#define GEN_PASS_DEF_PRINTCSHARPPASS
10401170#include " rlc/dialect/Passes.inc"
10411171
@@ -1049,18 +1179,25 @@ namespace mlir::rlc
10491179 MemberFunctionsTable table (getOperation ());
10501180 mlir::rlc::ModuleBuilder builder (getOperation ());
10511181
1182+ llvm::SmallVector<std::string> declaredFunNames;
1183+
10521184 emitPrelude (matcher.getWriter ());
10531185 matcher.addTypeSerializer ();
10541186 registerTypeConversion (matcher.getWriter ().getTypeSerializer ());
10551187 registerTypeConversionRaw (matcher.getWriter ().getTypeSerializer (1 ));
1056- matcher.getWriter ().writenl (" unsafe class RLCNative {" );
1057- matcher.add <CSharpFunctionDeclarationMatcher>(isMac, isWindows);
1058- matcher.add <CSharpActionDeclarationMatcher>(builder, isMac, isWindows);
1188+ matcher.getWriter ().writenl (" public unsafe class RLCNative {" );
1189+ emitDLLImporters (isMac, isWindows, matcher.getWriter ());
1190+ matcher.add <CSharpFunctionDeclarationMatcher>(
1191+ isMac, isWindows, declaredFunNames);
1192+ matcher.add <CSharpActionDeclarationMatcher>(
1193+ builder, isMac, isWindows, declaredFunNames);
10591194 matcher.apply (getOperation ());
1195+ emitSetTearDown (declaredFunNames, matcher.getWriter ());
10601196 matcher.getWriter ().writenl (" }" ).endLine ();
10611197
10621198 matcher.clearMatchers ();
10631199 matcher.getWriter ().writenl (" unsafe class RLC {" );
1200+
10641201 matcher.add <CSharpFunctionWrappersMatcher>();
10651202 matcher.add <CSharpActionWrappersMatcher>();
10661203 matcher.apply (getOperation ());
0 commit comments