1
1
package cask .internal
2
2
3
3
import java .io .{InputStream , PrintWriter , StringWriter }
4
-
5
4
import scala .collection .generic .CanBuildFrom
6
5
import scala .collection .mutable
7
6
import java .io .OutputStream
8
-
7
+ import java .lang .invoke .{MethodHandles , MethodType }
8
+ import java .util .concurrent .{Executor , ExecutorService , ForkJoinPool , ThreadFactory }
9
9
import scala .annotation .switch
10
10
import scala .concurrent .{ExecutionContext , Future , Promise }
11
+ import scala .util .Try
12
+ import scala .util .control .NonFatal
11
13
12
14
object Util {
15
+ private val lookup = MethodHandles .lookup()
16
+
17
+ import cask .util .Logger .Console .globalLogger
18
+
19
+ /**
20
+ * Create a virtual thread executor with the given executor as the scheduler.
21
+ * */
22
+ def createVirtualThreadExecutor (executor : Executor ): Option [ExecutorService ] = {
23
+ (for {
24
+ factory <- Try (createVirtualThreadFactory(" cask-handler-executor" , executor))
25
+ executor <- Try (createNewThreadPerTaskExecutor(factory))
26
+ } yield executor).toOption
27
+ }
28
+
29
+ /**
30
+ * Create a default cask virtual thread executor if possible.
31
+ * */
32
+ def createDefaultCaskVirtualThreadExecutor : Option [ExecutorService ] = {
33
+ for {
34
+ scheduler <- getDefaultVirtualThreadScheduler
35
+ executor <- createVirtualThreadExecutor(scheduler)
36
+ } yield executor
37
+ }
38
+
39
+ /**
40
+ * Try to get the default virtual thread scheduler, or null if not supported.
41
+ * */
42
+ def getDefaultVirtualThreadScheduler : Option [ForkJoinPool ] = {
43
+ try {
44
+ val virtualThreadClass = Class .forName(" java.lang.VirtualThread" )
45
+ val privateLookup = MethodHandles .privateLookupIn(virtualThreadClass, lookup)
46
+ val defaultSchedulerField = privateLookup.findStaticVarHandle(virtualThreadClass, " DEFAULT_SCHEDULER" , classOf [ForkJoinPool ])
47
+ Option (defaultSchedulerField.get().asInstanceOf [ForkJoinPool ])
48
+ } catch {
49
+ case NonFatal (e) =>
50
+ // --add-opens java.base/java.lang=ALL-UNNAMED
51
+ globalLogger.exception(e)
52
+ None
53
+ }
54
+ }
55
+
56
+ def createNewThreadPerTaskExecutor (threadFactory : ThreadFactory ): ExecutorService = {
57
+ try {
58
+ val executorsClazz = ClassLoader .getSystemClassLoader.loadClass(" java.util.concurrent.Executors" )
59
+ val newThreadPerTaskExecutorMethod = lookup.findStatic(
60
+ executorsClazz,
61
+ " newThreadPerTaskExecutor" ,
62
+ MethodType .methodType(classOf [ExecutorService ], classOf [ThreadFactory ]))
63
+ newThreadPerTaskExecutorMethod.invoke(threadFactory)
64
+ .asInstanceOf [ExecutorService ]
65
+ } catch {
66
+ case NonFatal (e) =>
67
+ globalLogger.exception(e)
68
+ throw new UnsupportedOperationException (" Failed to create newThreadPerTaskExecutor." , e)
69
+ }
70
+ }
71
+
72
+ /**
73
+ * Create a virtual thread factory with a executor, the executor will be used as the scheduler of
74
+ * virtual thread.
75
+ *
76
+ * The executor should run task on platform threads.
77
+ *
78
+ * returns null if not supported.
79
+ */
80
+ def createVirtualThreadFactory (prefix : String ,
81
+ executor : Executor ): ThreadFactory =
82
+ try {
83
+ val builderClass = ClassLoader .getSystemClassLoader.loadClass(" java.lang.Thread$Builder" )
84
+ val ofVirtualClass = ClassLoader .getSystemClassLoader.loadClass(" java.lang.Thread$Builder$OfVirtual" )
85
+ val ofVirtualMethod = lookup.findStatic(classOf [Thread ], " ofVirtual" , MethodType .methodType(ofVirtualClass))
86
+ var builder = ofVirtualMethod.invoke()
87
+ if (executor != null ) {
88
+ val clazz = builder.getClass
89
+ val privateLookup = MethodHandles .privateLookupIn(
90
+ clazz,
91
+ lookup
92
+ )
93
+ val schedulerFieldSetter = privateLookup
94
+ .findSetter(clazz, " scheduler" , classOf [Executor ])
95
+ schedulerFieldSetter.invoke(builder, executor)
96
+ }
97
+ val nameMethod = lookup.findVirtual(ofVirtualClass, " name" ,
98
+ MethodType .methodType(ofVirtualClass, classOf [String ], classOf [Long ]))
99
+ val factoryMethod = lookup.findVirtual(builderClass, " factory" , MethodType .methodType(classOf [ThreadFactory ]))
100
+ builder = nameMethod.invoke(builder, prefix + " -virtual-thread-" , 0L )
101
+ factoryMethod.invoke(builder).asInstanceOf [ThreadFactory ]
102
+ } catch {
103
+ case NonFatal (e) =>
104
+ globalLogger.exception(e)
105
+ // --add-opens java.base/java.lang=ALL-UNNAMED
106
+ throw new UnsupportedOperationException (" Failed to create virtual thread factory." , e)
107
+ }
108
+
13
109
def firstFutureOf [T ](futures : Seq [Future [T ]])(implicit ec : ExecutionContext ) = {
14
110
val p = Promise [T ]
15
111
futures.foreach(_.foreach(p.trySuccess))
16
112
p.future
17
113
}
114
+
18
115
/**
19
- * Convert a string to a C&P-able literal. Basically
20
- * copied verbatim from the uPickle source code.
21
- */
116
+ * Convert a string to a C&P-able literal. Basically
117
+ * copied verbatim from the uPickle source code.
118
+ */
22
119
def literalize (s : IndexedSeq [Char ], unicode : Boolean = true ) = {
23
120
val sb = new StringBuilder
24
121
sb.append('"' )
@@ -47,29 +144,30 @@ object Util {
47
144
def transferTo (in : InputStream , out : OutputStream ) = {
48
145
val buffer = new Array [Byte ](8192 )
49
146
50
- while ({
51
- in.read(buffer) match {
147
+ while ( {
148
+ in.read(buffer) match {
52
149
case - 1 => false
53
150
case n =>
54
151
out.write(buffer, 0 , n)
55
152
true
56
153
}
57
154
}) ()
58
155
}
156
+
59
157
def pluralize (s : String , n : Int ) = {
60
158
if (n == 1 ) s else s + " s"
61
159
}
62
160
63
161
/**
64
- * Splits a string into path segments; automatically removes all
65
- * leading/trailing slashes, and ignores empty path segments.
66
- *
67
- * Written imperatively for performance since it's used all over the place.
68
- */
162
+ * Splits a string into path segments; automatically removes all
163
+ * leading/trailing slashes, and ignores empty path segments.
164
+ *
165
+ * Written imperatively for performance since it's used all over the place.
166
+ */
69
167
def splitPath (p : String ): collection.IndexedSeq [String ] = {
70
168
val pLength = p.length
71
169
var i = 0
72
- while (i < pLength && p(i) == '/' ) i += 1
170
+ while (i < pLength && p(i) == '/' ) i += 1
73
171
var segmentStart = i
74
172
val out = mutable.ArrayBuffer .empty[String ]
75
173
@@ -81,7 +179,7 @@ object Util {
81
179
segmentStart = i + 1
82
180
}
83
181
84
- while (i < pLength){
182
+ while (i < pLength) {
85
183
if (p(i) == '/' ) complete()
86
184
i += 1
87
185
}
@@ -96,33 +194,35 @@ object Util {
96
194
pw.flush()
97
195
trace.toString
98
196
}
197
+
99
198
def softWrap (s : String , leftOffset : Int , maxWidth : Int ) = {
100
199
val oneLine = s.linesIterator.mkString(" " ).split(' ' )
101
200
102
201
lazy val indent = " " * leftOffset
103
202
104
203
val output = new StringBuilder (oneLine.head)
105
204
var currentLineWidth = oneLine.head.length
106
- for (chunk <- oneLine.tail){
205
+ for (chunk <- oneLine.tail) {
107
206
val addedWidth = currentLineWidth + chunk.length + 1
108
- if (addedWidth > maxWidth){
207
+ if (addedWidth > maxWidth) {
109
208
output.append(" \n " + indent)
110
209
output.append(chunk)
111
210
currentLineWidth = chunk.length
112
- } else {
211
+ } else {
113
212
currentLineWidth = addedWidth
114
213
output.append(' ' )
115
214
output.append(chunk)
116
215
}
117
216
}
118
217
output.mkString
119
218
}
219
+
120
220
def sequenceEither [A , B , M [X ] <: TraversableOnce [X ]](in : M [Either [A , B ]])(
121
221
implicit cbf : CanBuildFrom [M [Either [A , B ]], B , M [B ]]): Either [A , M [B ]] = {
122
222
in.foldLeft[Either [A , mutable.Builder [B , M [B ]]]](Right (cbf(in))) {
123
- case (acc, el) =>
124
- for (a <- acc; e <- el) yield a += e
125
- }
223
+ case (acc, el) =>
224
+ for (a <- acc; e <- el) yield a += e
225
+ }
126
226
.map(_.result())
127
227
}
128
228
}
0 commit comments