Skip to content

Commit f1ab1bc

Browse files
committed
Support of new training process
1 parent f7129f6 commit f1ab1bc

File tree

8 files changed

+573
-323
lines changed

8 files changed

+573
-323
lines changed

VSharp.API/VSharp.cs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -191,10 +191,10 @@ private static Statistics StartExploration(
191191
stopOnCoverageAchieved: 100,
192192
randomSeed: options.RandomSeed,
193193
stepsLimit: options.StepsLimit,
194-
aiAgentTrainingOptions: options.AIAgentTrainingOptions == null ? FSharpOption<AIAgentTrainingOptions>.None : FSharpOption<AIAgentTrainingOptions>.Some(options.AIAgentTrainingOptions),
194+
aiOptions: options.AIOptions == null ? FSharpOption<AIOptions>.None : FSharpOption<AIOptions>.Some(options.AIOptions),
195195
pathToModel: options.PathToModel == null ? FSharpOption<string>.None : FSharpOption<string>.Some(options.PathToModel),
196-
useGPU: options.UseGPU == null ? FSharpOption<bool>.None : FSharpOption<bool>.Some(options.UseGPU),
197-
optimize: options.Optimize == null ? FSharpOption<bool>.None : FSharpOption<bool>.Some(options.Optimize)
196+
useGPU: options.UseGPU,
197+
optimize: options.Optimize
198198
);
199199

200200
var fuzzerOptions =

VSharp.API/VSharpOptions.cs

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -113,7 +113,7 @@ public readonly record struct VSharpOptions
113113
public readonly bool ReleaseBranches = DefaultReleaseBranches;
114114
public readonly int RandomSeed = DefaultRandomSeed;
115115
public readonly uint StepsLimit = DefaultStepsLimit;
116-
public readonly AIAgentTrainingOptions AIAgentTrainingOptions = null;
116+
public readonly AIOptions? AIOptions = null;
117117
public readonly string PathToModel = DefaultPathToModel;
118118
public readonly bool UseGPU = false;
119119
public readonly bool Optimize = false;
@@ -133,7 +133,7 @@ public readonly record struct VSharpOptions
133133
/// <param name="releaseBranches">If true and timeout is specified, a part of allotted time in the end is given to execute remaining states without branching.</param>
134134
/// <param name="randomSeed">Fixed seed for random operations. Used if greater than or equal to zero.</param>
135135
/// <param name="stepsLimit">Number of symbolic machine steps to stop execution after. Zero value means no limit.</param>
136-
/// <param name="aiAgentTrainingOptions">Settings for AI searcher training.</param>
136+
/// <param name="aiOptions">Settings for AI searcher training.</param>
137137
/// <param name="pathToModel">Path to ONNX file with model to use in AI searcher.</param>
138138
/// <param name="useGPU">Specifies whether the ONNX execution session should use a CUDA-enabled GPU.</param>
139139
/// <param name="optimize">Enabling options like parallel execution and various graph transformations to enhance performance of ONNX.</param>
@@ -150,7 +150,7 @@ public VSharpOptions(
150150
bool releaseBranches = DefaultReleaseBranches,
151151
int randomSeed = DefaultRandomSeed,
152152
uint stepsLimit = DefaultStepsLimit,
153-
AIAgentTrainingOptions aiAgentTrainingOptions = null,
153+
AIOptions? aiOptions = null,
154154
string pathToModel = DefaultPathToModel,
155155
bool useGPU = false,
156156
bool optimize = false)
@@ -167,7 +167,7 @@ public VSharpOptions(
167167
ReleaseBranches = releaseBranches;
168168
RandomSeed = randomSeed;
169169
StepsLimit = stepsLimit;
170-
AIAgentTrainingOptions = aiAgentTrainingOptions;
170+
AIOptions = aiOptions;
171171
PathToModel = pathToModel;
172172
UseGPU = useGPU;
173173
Optimize = optimize;

VSharp.Explorer/AISearcher.fs

Lines changed: 112 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -2,48 +2,22 @@ namespace VSharp.Explorer
22

33
open System.Collections.Generic
44
open Microsoft.ML.OnnxRuntime
5+
open System
6+
open System.Text
7+
open System.Text.Json
58
open VSharp
69
open VSharp.IL.Serializer
710
open VSharp.ML.GameServer.Messages
811

9-
type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentTrainingOptions>) =
10-
let stepsToSwitchToAI =
11-
match aiAgentTrainingOptions with
12-
| None -> 0u<step>
13-
| Some options -> options.stepsToSwitchToAI
14-
15-
let stepsToPlay =
16-
match aiAgentTrainingOptions with
17-
| None -> 0u<step>
18-
| Some options -> options.stepsToPlay
19-
20-
let mutable lastCollectedStatistics = Statistics()
21-
let mutable defaultSearcherSteps = 0u<step>
22-
let mutable (gameState: Option<GameState>) = None
23-
let mutable useDefaultSearcher = stepsToSwitchToAI > 0u<step>
24-
let mutable afterFirstAIPeek = false
25-
let mutable incorrectPredictedStateId = false
26-
27-
let defaultSearcher =
28-
match aiAgentTrainingOptions with
29-
| None -> BFSSearcher() :> IForwardSearcher
30-
| Some options ->
31-
match options.defaultSearchStrategy with
32-
| BFSMode -> BFSSearcher() :> IForwardSearcher
33-
| DFSMode -> DFSSearcher() :> IForwardSearcher
34-
| x -> failwithf $"Unexpected default searcher {x}. DFS and BFS supported for now."
35-
36-
let mutable stepsPlayed = 0u<step>
37-
38-
let isInAIMode () =
39-
(not useDefaultSearcher) && afterFirstAIPeek
40-
41-
let q = ResizeArray<_>()
42-
let availableStates = HashSet<_>()
12+
type AIMode =
13+
| Runner
14+
| TrainingSendModel
15+
| TrainingSendEachStep
4316

44-
let updateGameState (delta: GameState) =
17+
module GameUtils =
18+
let updateGameState (delta: GameState) (gameState: Option<GameState>) =
4519
match gameState with
46-
| None -> gameState <- Some delta
20+
| None -> Some delta
4721
| Some s ->
4822
let updatedBasicBlocks = delta.GraphVertices |> Array.map (fun b -> b.Id) |> HashSet
4923
let updatedStates = delta.States |> Array.map (fun s -> s.Id) |> HashSet
@@ -86,14 +60,56 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
8660
s.Children |> Array.filter activeStates.Contains
8761
))
8862

89-
let pathConditionVertices =
90-
ResizeArray<PathConditionVertex> s.PathConditionVertices
63+
let pathConditionVertices = ResizeArray<PathConditionVertex> s.PathConditionVertices
9164

9265
pathConditionVertices.AddRange delta.PathConditionVertices
9366

94-
gameState <-
95-
Some
96-
<| GameState(vertices.ToArray(), states, pathConditionVertices.ToArray(), edges.ToArray())
67+
Some <| GameState(vertices.ToArray(), states, pathConditionVertices.ToArray(), edges.ToArray())
68+
69+
let convertOutputToJson (output: IDisposableReadOnlyCollection<OrtValue>) =
70+
seq { 0 .. output.Count - 1 }
71+
|> Seq.map (fun i -> output[i].GetTensorDataAsSpan<float32>().ToArray())
72+
73+
type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrainingMode>) =
74+
let stepsToSwitchToAI =
75+
match aiAgentTrainingMode with
76+
| None -> 0u<step>
77+
| Some(SendModel options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
78+
| Some(SendEachStep options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
79+
80+
let stepsToPlay =
81+
match aiAgentTrainingMode with
82+
| None -> 0u<step>
83+
| Some(SendModel options) -> options.aiAgentTrainingOptions.stepsToPlay
84+
| Some(SendEachStep options) -> options.aiAgentTrainingOptions.stepsToPlay
85+
86+
let mutable lastCollectedStatistics = Statistics()
87+
let mutable defaultSearcherSteps = 0u<step>
88+
let mutable (gameState: Option<GameState>) = None
89+
let mutable useDefaultSearcher = stepsToSwitchToAI > 0u<step>
90+
let mutable afterFirstAIPeek = false
91+
let mutable incorrectPredictedStateId = false
92+
93+
let defaultSearcher =
94+
let pickSearcher =
95+
function
96+
| BFSMode -> BFSSearcher() :> IForwardSearcher
97+
| DFSMode -> DFSSearcher() :> IForwardSearcher
98+
| x -> failwithf $"Unexpected default searcher {x}. DFS and BFS supported for now."
99+
100+
match aiAgentTrainingMode with
101+
| None -> BFSSearcher() :> IForwardSearcher
102+
| Some(SendModel options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
103+
| Some(SendEachStep options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
104+
105+
let mutable stepsPlayed = 0u<step>
106+
107+
let isInAIMode () =
108+
(not useDefaultSearcher) && afterFirstAIPeek
109+
110+
let q = ResizeArray<_>()
111+
let availableStates = HashSet<_>()
112+
97113

98114

99115
let init states =
@@ -128,15 +144,19 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
128144
for bb in state._history do
129145
bb.Key.AssociatedStates.Remove state |> ignore
130146

131-
let inTrainMode = aiAgentTrainingOptions.IsSome
147+
let aiMode =
148+
match aiAgentTrainingMode with
149+
| Some(SendEachStep _) -> TrainingSendEachStep
150+
| Some(SendModel _) -> TrainingSendModel
151+
| None -> Runner
132152

133153
let pick selector =
134154
if useDefaultSearcher then
135155
defaultSearcherSteps <- defaultSearcherSteps + 1u<step>
136156

137157
if Seq.length availableStates > 0 then
138158
let gameStateDelta = collectGameStateDelta ()
139-
updateGameState gameStateDelta
159+
gameState <- GameUtils.updateGameState gameStateDelta gameState
140160
let statistics = computeStatistics gameState.Value
141161
Application.applicationGraphDelta.Clear()
142162
lastCollectedStatistics <- statistics
@@ -149,7 +169,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
149169
Some(Seq.head availableStates)
150170
else
151171
let gameStateDelta = collectGameStateDelta ()
152-
updateGameState gameStateDelta
172+
gameState <- GameUtils.updateGameState gameStateDelta gameState
153173
let statistics = computeStatistics gameState.Value
154174

155175
if isInAIMode () then
@@ -158,14 +178,18 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
158178

159179
Application.applicationGraphDelta.Clear()
160180

161-
if inTrainMode && stepsToPlay = stepsPlayed then
181+
if stepsToPlay = stepsPlayed then
162182
None
163183
else
164184
let toPredict =
165-
if inTrainMode && stepsPlayed > 0u<step> then
166-
gameStateDelta
167-
else
168-
gameState.Value
185+
match aiMode with
186+
| TrainingSendEachStep
187+
| TrainingSendModel ->
188+
if stepsPlayed > 0u<step> then
189+
gameStateDelta
190+
else
191+
gameState.Value
192+
| Runner -> gameState.Value
169193

170194
let stateId = oracle.Predict toPredict
171195
afterFirstAIPeek <- true
@@ -180,12 +204,19 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
180204
oracle.Feedback(Feedback.IncorrectPredictedStateId stateId)
181205
None
182206

183-
new(pathToONNX: string, useGPU: bool, optimize: bool) =
207+
new
208+
(
209+
pathToONNX: string,
210+
useGPU: bool,
211+
optimize: bool,
212+
aiAgentTrainingModelOptions: Option<AIAgentTrainingModelOptions>
213+
) =
184214
let numOfVertexAttributes = 7
185215
let numOfStateAttributes = 7
186216
let numOfHistoryEdgeAttributes = 2
187217

188-
let createOracle (pathToONNX: string) =
218+
219+
let createOracleRunner (pathToONNX: string, aiAgentTrainingModelOptions: Option<AIAgentTrainingModelOptions>) =
189220
let sessionOptions =
190221
if useGPU then
191222
SessionOptions.MakeSessionOptionWithCudaProvider(0)
@@ -199,10 +230,21 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
199230
sessionOptions.GraphOptimizationLevel <- GraphOptimizationLevel.ORT_ENABLE_BASIC
200231

201232
let session = new InferenceSession(pathToONNX, sessionOptions)
233+
202234
let runOptions = new RunOptions()
203235
let feedback (x: Feedback) = ()
204236

205-
let predict (gameState: GameState) =
237+
let mutable stepsPlayed = 0
238+
let mutable currentGameState = None
239+
240+
let predict (gameStateOrDelta: GameState) =
241+
let _ =
242+
match aiAgentTrainingModelOptions with
243+
| Some _ when not (stepsPlayed = 0) ->
244+
currentGameState <- GameUtils.updateGameState gameStateOrDelta currentGameState
245+
| _ -> currentGameState <- Some gameStateOrDelta
246+
247+
let gameState = currentGameState.Value
206248
let stateIds = Dictionary<uint<stateId>, int>()
207249
let verticesIds = Dictionary<uint<basicBlockGlobalId>, int>()
208250

@@ -243,7 +285,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
243285
let j = i * numOfStateAttributes
244286
attributes.[j] <- float32 v.Position
245287
// TODO: Support path condition
246-
// attributes.[j + 1] <- float32 v.PathConditionSize
288+
// attributes.[j + 1] <- float32 v.PathConditionSize
247289
attributes.[j + 2] <- float32 v.VisitedAgainVertices
248290
attributes.[j + 3] <- float32 v.VisitedNotCoveredVerticesInZone
249291
attributes.[j + 4] <- float32 v.VisitedNotCoveredVerticesOutOfZone
@@ -350,14 +392,30 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
350392
res
351393

352394
let output = session.Run(runOptions, networkInput, session.OutputNames)
395+
396+
let _ =
397+
match aiAgentTrainingModelOptions with
398+
| Some aiAgentOptions ->
399+
aiAgentOptions.stepSaver (
400+
AIGameStep(gameState = gameStateOrDelta, output = GameUtils.convertOutputToJson output)
401+
)
402+
| None -> ()
403+
404+
stepsPlayed <- stepsPlayed + 1
405+
353406
let weighedStates = output[0].GetTensorDataAsSpan<float32>().ToArray()
354407

355408
let id = weighedStates |> Array.mapi (fun i v -> i, v) |> Array.maxBy snd |> fst
356409
stateIds |> Seq.find (fun kvp -> kvp.Value = id) |> (fun x -> x.Key)
357410

358411
Oracle(predict, feedback)
359412

360-
AISearcher(createOracle pathToONNX, None)
413+
let aiAgentTrainingOptions =
414+
match aiAgentTrainingModelOptions with
415+
| Some aiAgentTrainingModelOptions -> Some(SendModel aiAgentTrainingModelOptions)
416+
| None -> None
417+
418+
AISearcher(createOracleRunner (pathToONNX, aiAgentTrainingModelOptions), aiAgentTrainingOptions)
361419

362420
interface IForwardSearcher with
363421
override x.Init states = init states

0 commit comments

Comments
 (0)