-
Notifications
You must be signed in to change notification settings - Fork 45
Expand file tree
/
Copy pathChildProcess.kt
More file actions
320 lines (287 loc) · 11.4 KB
/
ChildProcess.kt
File metadata and controls
320 lines (287 loc) · 11.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
package org.utbot.instrumentation.process
import com.jetbrains.rd.framework.*
import com.jetbrains.rd.framework.impl.RdCall
import com.jetbrains.rd.util.ILoggerFactory
import com.jetbrains.rd.util.LogLevel
import com.jetbrains.rd.util.Logger
import com.jetbrains.rd.util.defaultLogFormat
import com.jetbrains.rd.util.lifetime.Lifetime
import com.jetbrains.rd.util.lifetime.LifetimeDefinition
import com.jetbrains.rd.util.lifetime.plusAssign
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeoutOrNull
import org.utbot.common.*
import org.utbot.framework.plugin.api.util.UtContext
import org.utbot.instrumentation.agent.Agent
import org.utbot.instrumentation.instrumentation.Instrumentation
import org.utbot.instrumentation.instrumentation.coverage.CoverageInstrumentation
import org.utbot.instrumentation.rd.childCreatedFileName
import org.utbot.instrumentation.rd.generated.CollectCoverageResult
import org.utbot.instrumentation.rd.generated.InvokeMethodCommandResult
import org.utbot.instrumentation.rd.generated.ProtocolModel
import org.utbot.instrumentation.rd.obtainClientIO
import org.utbot.instrumentation.rd.processSyncDirectory
import org.utbot.instrumentation.rd.signalChildReady
import org.utbot.instrumentation.util.KryoHelper
import org.utbot.rd.UtRdCoroutineScope
import org.utbot.rd.adviseForConditionAsync
import java.io.File
import java.io.OutputStream
import java.io.PrintStream
import java.net.URLClassLoader
import java.security.AllPermission
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import java.util.concurrent.TimeUnit
import kotlin.system.measureTimeMillis
import org.utbot.framework.plugin.api.FieldId
import org.utbot.instrumentation.rd.generated.ComputeStaticFieldResult
/**
 * Class loader that keeps the user's classes apart from our own dependency classes.
 * Our classes won't be instrumented.
 */
private object HandlerClassesLoader : URLClassLoader(emptyArray()) {
    /**
     * Appends every given file-system path to this loader's search path.
     *
     * @param urls file-system paths (not URL strings); each one is converted through [File.toURI]
     */
    fun addUrls(urls: Iterable<String>) {
        for (path in urls) {
            super.addURL(File(path).toURI().toURL())
        }
    }

    /**
     * The system class loader is able to find org.slf4j, so whenever something from org.slf4j
     * has to be mocked we want that class to be defined by [HandlerClassesLoader] itself:
     * such names skip parent delegation and are looked up only in this loader.
     * Every other name follows the regular [URLClassLoader] lookup.
     */
    override fun loadClass(name: String, resolve: Boolean): Class<*> {
        if (!name.startsWith("org.slf4j")) {
            return super.loadClass(name, resolve)
        }

        val ownClass = findLoadedClass(name) ?: findClass(name)
        if (resolve) {
            resolveClass(ownClass)
        }
        return ownClass
    }
}
// The child process reuses rd's LogLevel enum as its own set of log levels.
private typealias ChildProcessLogLevel = LogLevel
// Threshold for the log helpers below: messages whose level is lower than this one are dropped.
private val logLevel = ChildProcessLogLevel.Info
// Logging
// Timestamp pattern (hours:minutes:seconds.milliseconds) put at the start of every printed log line.
private val dateFormatter: DateTimeFormatter = DateTimeFormatter.ofPattern("HH:mm:ss.SSS")
/**
 * Prints a single log line to the standard error stream in the form
 * `<time> <LEVEL>| <message>`.
 *
 * Nothing is printed, and [any] is never invoked, when [level] is below [logLevel].
 *
 * @param level severity of the message
 * @param any lazily evaluated producer of the message
 */
private inline fun log(level: ChildProcessLogLevel, any: () -> Any?) {
    if (level >= logLevel) {
        val timestamp = LocalDateTime.now().format(dateFormatter)
        System.err.println("$timestamp ${level.name.uppercase()}| ${any()}")
    }
}
/**
 * Logs at the error level: problems that must be addressed.
 *
 * @param any lazily evaluated producer of the message
 */
internal inline fun logError(any: () -> Any?) = log(ChildProcessLogLevel.Error, any)
/**
 * Logs at the info level, the default one: infrequent, useful messages that do not pollute the log.
 *
 * @param any lazily evaluated producer of the message
 */
internal inline fun logInfo(any: () -> Any?) = log(ChildProcessLogLevel.Info, any)
/**
 * Logs at the debug level: frequent messages that are useful while debugging.
 *
 * @param any lazily evaluated producer of the message
 */
internal inline fun logDebug(any: () -> Any?) = log(ChildProcessLogLevel.Debug, any)
/**
 * Logs at the trace level: internal rd logs and very frequent messages.
 * Heavily pollutes the log, useful only when debugging rpc,
 * and probably contains no info about utbot.
 *
 * Declared `inline` like the other log helpers, so that a trace call whose level is
 * filtered out does not allocate a lambda object for the message producer.
 *
 * @param any lazily evaluated producer of the message
 */
internal inline fun logTrace(any: () -> Any?) {
    log(ChildProcessLogLevel.Trace, any)
}
/**
 * Markers sent through [synchronizer] around the execution of a request handler;
 * the watchdog coroutine in [main] remembers the last one it received.
 */
private enum class State {
// sent right before a handler's body runs
STARTED,
// sent after a handler's body has finished, whether it returned or threw
ENDED
}
// How long (120 seconds, in milliseconds) the watchdog in main() waits for the next State message.
private val messageFromMainTimeoutMillis: Long = TimeUnit.SECONDS.toMillis(120)
// Channel with a buffer of one element through which handlers report STARTED/ENDED to the watchdog in main().
private val synchronizer: Channel<State> = Channel(1)
/**
 * Command-line option to disable the sandbox.
 *
 * [main] looks for this exact string among the process arguments and installs its
 * `permissions { ... }` block only when the option is absent.
 */
const val DISABLE_SANDBOX_OPTION = "--disable-sandbox"
/**
 * Entry point of the child process.
 * It should be compiled into separate jar file (child_process.jar) and be run with an agent (agent.jar) option.
 *
 * The function
 * 1. installs the `permissions { ... }` block unless [DISABLE_SANDBOX_OPTION] is passed,
 * 2. reads the port from the argument that starts with `serverPortProcessArgumentTag`,
 * 3. launches a watchdog coroutine that terminates the root lifetime when no [State] message
 *    arrives within [messageFromMainTimeoutMillis] and the last received state was [State.ENDED],
 * 4. runs [initiate] inside a nested lifetime and removes the child's sync file afterwards.
 *
 * @param args process arguments
 * @throws IllegalArgumentException if the port argument is missing, or its value is not an
 * integer in `1..65535`
 */
suspend fun main(args: Array<String>) = runBlocking {
    if (!args.contains(DISABLE_SANDBOX_OPTION)) {
        permissions {
            // Enable all permissions for instrumentation.
            // SecurityKt.sandbox() is used to restrict these permissions.
            +AllPermission()
        }
    }

    // 0 - auto port for server, should not be used here.
    // An out-of-range or non-numeric value is rejected instead of being clamped into the range:
    // clamping would make the process silently connect to a port nobody asked for.
    val portArgument = args.find { it.startsWith(serverPortProcessArgumentTag) }
        ?: throw IllegalArgumentException("No port provided")
    val port = portArgument.split("=").last().toIntOrNull()?.takeIf { it in 1..65535 }
        ?: throw IllegalArgumentException("Invalid port in \"$portArgument\": expected an integer in 1..65535")

    val pid = currentProcessPid.toInt()
    val def = LifetimeDefinition()

    // Watchdog: follows the STARTED/ENDED messages that handlers send through the synchronizer.
    launch {
        var lastState = State.STARTED
        while (true) {
            // null means that nothing was received within the timeout
            val current: State? =
                withTimeoutOrNull(messageFromMainTimeoutMillis) {
                    synchronizer.receive()
                }
            if (current == null) {
                // A timeout while a handler is still running (last state is STARTED) is tolerated.
                if (lastState == State.ENDED) {
                    // process is waiting for command more than expected, better die
                    logInfo { "terminating lifetime" }
                    def.terminate()
                    break
                }
            }
            else {
                lastState = current
            }
        }
    }

    def.usingNested { lifetime ->
        lifetime += { logInfo { "lifetime terminated" } }
        try {
            logInfo {"pid - $pid"}
            logInfo {"isJvm8 - $isJvm8, isJvm9Plus - $isJvm9Plus, isWindows - $isWindows"}
            initiate(lifetime, port, pid)
        } finally {
            // Remove the sync file of this child, if it is still there, however initiate() ended.
            val syncFile = File(processSyncDirectory, childCreatedFileName(pid))
            if (syncFile.exists()) {
                logInfo { "sync file existed" }
                syncFile.delete()
            }
        }
    }
}
/**
 * Runs [block] while reporting its execution to the watchdog through [synchronizer]:
 * [State.STARTED] is sent before the block and [State.ENDED] after it, including the case
 * when the block (or the first send) throws.
 *
 * The calling thread is blocked until both messages are sent and the block has finished.
 *
 * @param block the work to execute
 * @return whatever [block] returns
 */
private fun <T> measureExecutionForTermination(block: () -> T): T {
    return runBlocking {
        try {
            synchronizer.send(State.STARTED)
            block()
        } finally {
            synchronizer.send(State.ENDED)
        }
    }
}
// Paths to the user's classes; assigned by the addPaths handler in ProtocolModel.setup.
private lateinit var pathsToUserClasses: Set<String>
// Paths to the dependency classes; assigned by the addPaths handler in ProtocolModel.setup.
private lateinit var pathsToDependencyClasses: Set<String>
// Current instrumentation; assigned by the setInstrumentation handler and used by the
// invokeMethodCommand, collectCoverage and computeStaticField handlers.
private lateinit var instrumentation: Instrumentation<*>
/**
 * Installs [block] as the handler of this [RdCall], wrapping every invocation so that
 * - its start and end are reported to the watchdog (see the top-level
 *   `measureExecutionForTermination`), and
 * - any [Throwable] it throws is logged with its stack trace and then rethrown unchanged.
 *
 * The top-level helper already runs the body inside its own `runBlocking`, so no additional
 * `runBlocking` is wrapped around it here.
 *
 * @param block handler that turns a request of type [T] into a response of type [R]
 */
private fun <T, R> RdCall<T, R>.measureExecutionForTermination(block: (T) -> R) {
    this.set { request ->
        measureExecutionForTermination<R> {
            try {
                block(request)
            } catch (e: Throwable) {
                logError { e.stackTraceToString() }
                throw e
            }
        }
    }
}
/**
 * Registers the handlers of all calls of this [ProtocolModel].
 * Every handler is installed through [measureExecutionForTermination], so each request is
 * reported to the watchdog and its failures are logged.
 *
 * @param kryoHelper reads request payloads and writes results
 * @param onStop callback invoked when a stop request is received
 */
private fun ProtocolModel.setup(kryoHelper: KryoHelper, onStop: () -> Unit) {
    // warmup: walk over the classes visible to the handler class loader
    warmup.measureExecutionForTermination {
        logDebug { "received warmup request" }
        val elapsedMillis = measureTimeMillis {
            HandlerClassesLoader.scanForClasses("").toList() // here we transform classes
        }
        logDebug { "warmup finished in $elapsedMillis ms" }
    }

    // invokeMethodCommand: load the requested class and run the method through the instrumentation
    invokeMethodCommand.measureExecutionForTermination { request ->
        logDebug { "received invokeMethod request: ${request.classname}, ${request.signature}" }
        val targetClass = HandlerClassesLoader.loadClass(request.classname)
        val invocationResult = instrumentation.invoke(
            targetClass,
            request.signature,
            kryoHelper.readObject(request.arguments),
            kryoHelper.readObject(request.parameters)
        )
        logDebug { "invokeMethod result: $invocationResult" }
        InvokeMethodCommandResult(kryoHelper.writeObject(invocationResult))
    }

    // setInstrumentation: deserialize the instrumentation and plug it into the agent
    setInstrumentation.measureExecutionForTermination { request ->
        logDebug { "setInstrumentation request" }
        instrumentation = kryoHelper.readObject(request.instrumentation)
        logTrace { "instrumentation - ${instrumentation.javaClass.name} " }
        val classTransformer = Agent.dynamicClassTransformer
        classTransformer.transformer = instrumentation // classTransformer is set
        classTransformer.addUserPaths(pathsToUserClasses)
        instrumentation.init(pathsToUserClasses)
    }

    // addPaths: remember both path sets and make them visible to the handler class loader
    addPaths.measureExecutionForTermination { request ->
        logDebug { "addPaths request" }
        logTrace { "path to userClasses - ${request.pathsToUserClasses}"}
        logTrace { "path to dependencyClasses - ${request.pathsToDependencyClasses}"}
        val separator = File.pathSeparatorChar
        pathsToUserClasses = request.pathsToUserClasses.split(separator).toSet()
        pathsToDependencyClasses = request.pathsToDependencyClasses.split(separator).toSet()
        HandlerClassesLoader.addUrls(pathsToUserClasses)
        HandlerClassesLoader.addUrls(pathsToDependencyClasses)
        // Now kryo will use our classloader when it encounters unregistered class.
        kryoHelper.setKryoClassLoader(HandlerClassesLoader)
        UtContext.setUtContext(UtContext(HandlerClassesLoader))
    }

    // stopProcess: hand control to the caller-supplied callback
    stopProcess.measureExecutionForTermination {
        logDebug { "stop request" }
        onStop()
    }

    // collectCoverage: the current instrumentation is required to be a CoverageInstrumentation
    collectCoverage.measureExecutionForTermination { request ->
        logDebug { "collect coverage request" }
        val targetClass = kryoHelper.readObject<Class<*>>(request.clazz)
        logTrace { "class - ${targetClass.name}" }
        val coverageInstrumentation = instrumentation as CoverageInstrumentation
        val coverage = coverageInstrumentation.collectCoverageInfo(targetClass)
        CollectCoverageResult(kryoHelper.writeObject(coverage))
    }

    // computeStaticField: read the value of the requested static field through the instrumentation
    computeStaticField.measureExecutionForTermination { request ->
        val fieldId: FieldId = kryoHelper.readObject(request.fieldId)
        val fieldValue = instrumentation.getStaticField(fieldId)
        ComputeStaticFieldResult(kryoHelper.writeObject(fieldValue))
    }
}
/**
 * Connects the child process to the main process and serves requests until it is told to stop.
 *
 * The function
 * 1. replaces the standard output with a stream that discards everything,
 * 2. routes rd's own logging into this file's log helpers for the duration of [lifetime],
 * 3. creates the client [Protocol] on [port], obtains the sync signal and the protocol model,
 *    and registers the request handlers,
 * 4. calls `signalChildReady` and waits for the "main" message on the sync signal,
 *    answering it with "child",
 * 5. suspends until either a stop request arrives or [lifetime] is terminated.
 *
 * Exceptions thrown while waiting are logged and not rethrown.
 *
 * @param lifetime lifetime that bounds the logger, kryo, the protocol and its wire
 * @param port port of the main process to connect to
 * @param pid pid of this process, passed to `signalChildReady`
 */
private suspend fun initiate(lifetime: Lifetime, port: Int, pid: Int) {
    // We don't want user code to litter the standard output, so we redirect it.
    val tmpStream = PrintStream(object : OutputStream() {
        override fun write(b: Int) {}
    })
    System.setOut(tmpStream)

    Logger.set(lifetime, object : ILoggerFactory {
        override fun getLogger(category: String) = object : Logger {
            override fun isEnabled(level: LogLevel): Boolean {
                return level >= logLevel
            }

            override fun log(level: LogLevel, message: Any?, throwable: Throwable?) {
                val msg = defaultLogFormat(category, level, message, throwable)
                // Forward the message with its own level, so that the printed line carries the real
                // severity instead of the configured threshold and is filtered like isEnabled does.
                log(level) { msg }
            }
        }
    })

    // Completed either by a stop request or by the termination of the lifetime.
    val deferred = CompletableDeferred<Unit>()
    lifetime.onTermination { deferred.complete(Unit) }

    val kryoHelper = KryoHelper(lifetime)
    logInfo { "kryo created" }

    val clientProtocol = Protocol(
        "ChildProcess",
        Serializers(),
        Identities(IdKind.Client),
        UtRdCoroutineScope.scheduler,
        SocketWire.Client(lifetime, UtRdCoroutineScope.scheduler, port),
        lifetime
    )
    val (sync, protocolModel) = obtainClientIO(lifetime, clientProtocol)

    protocolModel.setup(kryoHelper) {
        deferred.complete(Unit)
    }
    signalChildReady(pid)
    logInfo { "IO obtained" }

    // Handshake: wait for "main" on the sync signal and reply with "child".
    val answerFromMainProcess = sync.adviseForConditionAsync(lifetime) {
        if (it == "main") {
            logTrace { "received from main" }
            // the reply is reported to the watchdog like any other handler execution
            measureExecutionForTermination {
                sync.fire("child")
            }
            true
        } else {
            false
        }
    }

    try {
        answerFromMainProcess.await()
        logInfo { "starting instrumenting" }
        deferred.await()
    } catch (e: Throwable) {
        logError { "Terminating process because exception occurred: ${e.stackTraceToString()}" }
    }
}