@@ -2,48 +2,22 @@ namespace VSharp.Explorer
22
33open System.Collections .Generic
44open Microsoft.ML .OnnxRuntime
5+ open System
6+ open System.Text
7+ open System.Text .Json
58open VSharp
69open VSharp.IL .Serializer
710open VSharp.ML .GameServer .Messages
811
9- type internal AISearcher ( oracle : Oracle , aiAgentTrainingOptions : Option < AIAgentTrainingOptions >) =
10- let stepsToSwitchToAI =
11- match aiAgentTrainingOptions with
12- | None -> 0 u< step>
13- | Some options -> options.stepsToSwitchToAI
14-
15- let stepsToPlay =
16- match aiAgentTrainingOptions with
17- | None -> 0 u< step>
18- | Some options -> options.stepsToPlay
19-
20- let mutable lastCollectedStatistics = Statistics()
21- let mutable defaultSearcherSteps = 0 u< step>
22- let mutable ( gameState : Option < GameState >) = None
23- let mutable useDefaultSearcher = stepsToSwitchToAI > 0 u< step>
24- let mutable afterFirstAIPeek = false
25- let mutable incorrectPredictedStateId = false
26-
27- let defaultSearcher =
28- match aiAgentTrainingOptions with
29- | None -> BFSSearcher() :> IForwardSearcher
30- | Some options ->
31- match options.defaultSearchStrategy with
32- | BFSMode -> BFSSearcher() :> IForwardSearcher
33- | DFSMode -> DFSSearcher() :> IForwardSearcher
34- | x -> failwithf $" Unexpected default searcher {x}. DFS and BFS supported for now."
35-
36- let mutable stepsPlayed = 0 u< step>
37-
38- let isInAIMode () =
39- ( not useDefaultSearcher) && afterFirstAIPeek
40-
41- let q = ResizeArray<_>()
42- let availableStates = HashSet<_>()
12+ type AIMode =
13+ | Runner
14+ | TrainingSendModel
15+ | TrainingSendEachStep
4316
44- let updateGameState ( delta : GameState ) =
17+ module GameUtils =
18+ let updateGameState ( delta : GameState ) ( gameState : Option < GameState >) =
4519 match gameState with
46- | None -> gameState <- Some delta
20+ | None -> Some delta
4721 | Some s ->
4822 let updatedBasicBlocks = delta.GraphVertices |> Array.map ( fun b -> b.Id) |> HashSet
4923 let updatedStates = delta.States |> Array.map ( fun s -> s.Id) |> HashSet
@@ -86,14 +60,56 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
8660 s.Children |> Array.filter activeStates.Contains
8761 ))
8862
89- let pathConditionVertices =
90- ResizeArray< PathConditionVertex> s.PathConditionVertices
63+ let pathConditionVertices = ResizeArray< PathConditionVertex> s.PathConditionVertices
9164
9265 pathConditionVertices.AddRange delta.PathConditionVertices
9366
94- gameState <-
95- Some
96- <| GameState( vertices.ToArray(), states, pathConditionVertices.ToArray(), edges.ToArray())
67+ Some <| GameState( vertices.ToArray(), states, pathConditionVertices.ToArray(), edges.ToArray())
68+
69+ let convertOutputToJson ( output : IDisposableReadOnlyCollection < OrtValue >) =
70+ seq { 0 .. output.Count - 1 }
71+ |> Seq.map ( fun i -> output[ i]. GetTensorDataAsSpan< float32>() .ToArray())
72+
73+ type internal AISearcher ( oracle : Oracle , aiAgentTrainingMode : Option < AIAgentTrainingMode >) =
74+ let stepsToSwitchToAI =
75+ match aiAgentTrainingMode with
76+ | None -> 0 u< step>
77+ | Some( SendModel options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
78+ | Some( SendEachStep options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
79+
80+ let stepsToPlay =
81+ match aiAgentTrainingMode with
82+ | None -> 0 u< step>
83+ | Some( SendModel options) -> options.aiAgentTrainingOptions.stepsToPlay
84+ | Some( SendEachStep options) -> options.aiAgentTrainingOptions.stepsToPlay
85+
86+ let mutable lastCollectedStatistics = Statistics()
87+ let mutable defaultSearcherSteps = 0 u< step>
88+ let mutable ( gameState : Option < GameState >) = None
89+ let mutable useDefaultSearcher = stepsToSwitchToAI > 0 u< step>
90+ let mutable afterFirstAIPeek = false
91+ let mutable incorrectPredictedStateId = false
92+
93+ let defaultSearcher =
94+ let pickSearcher =
95+ function
96+ | BFSMode -> BFSSearcher() :> IForwardSearcher
97+ | DFSMode -> DFSSearcher() :> IForwardSearcher
98+ | x -> failwithf $" Unexpected default searcher {x}. DFS and BFS supported for now."
99+
100+ match aiAgentTrainingMode with
101+ | None -> BFSSearcher() :> IForwardSearcher
102+ | Some( SendModel options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
103+ | Some( SendEachStep options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
104+
105+ let mutable stepsPlayed = 0 u< step>
106+
107+ let isInAIMode () =
108+ ( not useDefaultSearcher) && afterFirstAIPeek
109+
110+ let q = ResizeArray<_>()
111+ let availableStates = HashSet<_>()
112+
97113
98114
99115 let init states =
@@ -128,15 +144,19 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
128144 for bb in state._ history do
129145 bb.Key.AssociatedStates.Remove state |> ignore
130146
131- let inTrainMode = aiAgentTrainingOptions.IsSome
147+ let aiMode =
148+ match aiAgentTrainingMode with
149+ | Some( SendEachStep _) -> TrainingSendEachStep
150+ | Some( SendModel _) -> TrainingSendModel
151+ | None -> Runner
132152
133153 let pick selector =
134154 if useDefaultSearcher then
135155 defaultSearcherSteps <- defaultSearcherSteps + 1 u< step>
136156
137157 if Seq.length availableStates > 0 then
138158 let gameStateDelta = collectGameStateDelta ()
139- updateGameState gameStateDelta
159+ gameState <- GameUtils. updateGameState gameStateDelta gameState
140160 let statistics = computeStatistics gameState.Value
141161 Application.applicationGraphDelta.Clear()
142162 lastCollectedStatistics <- statistics
@@ -149,7 +169,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
149169 Some( Seq.head availableStates)
150170 else
151171 let gameStateDelta = collectGameStateDelta ()
152- updateGameState gameStateDelta
172+ gameState <- GameUtils. updateGameState gameStateDelta gameState
153173 let statistics = computeStatistics gameState.Value
154174
155175 if isInAIMode () then
@@ -158,14 +178,18 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
158178
159179 Application.applicationGraphDelta.Clear()
160180
161- if inTrainMode && stepsToPlay = stepsPlayed then
181+ if stepsToPlay = stepsPlayed then
162182 None
163183 else
164184 let toPredict =
165- if inTrainMode && stepsPlayed > 0 u< step> then
166- gameStateDelta
167- else
168- gameState.Value
185+ match aiMode with
186+ | TrainingSendEachStep
187+ | TrainingSendModel ->
188+ if stepsPlayed > 0 u< step> then
189+ gameStateDelta
190+ else
191+ gameState.Value
192+ | Runner -> gameState.Value
169193
170194 let stateId = oracle.Predict toPredict
171195 afterFirstAIPeek <- true
@@ -180,12 +204,19 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
180204 oracle.Feedback( Feedback.IncorrectPredictedStateId stateId)
181205 None
182206
183- new ( pathToONNX: string, useGPU: bool, optimize: bool) =
207+ new
208+ (
209+ pathToONNX: string,
210+ useGPU: bool,
211+ optimize: bool,
212+ aiAgentTrainingModelOptions: Option< AIAgentTrainingModelOptions>
213+ ) =
184214 let numOfVertexAttributes = 7
185215 let numOfStateAttributes = 7
186216 let numOfHistoryEdgeAttributes = 2
187217
188- let createOracle ( pathToONNX : string ) =
218+
219+ let createOracleRunner ( pathToONNX : string , aiAgentTrainingModelOptions : Option < AIAgentTrainingModelOptions >) =
189220 let sessionOptions =
190221 if useGPU then
191222 SessionOptions.MakeSessionOptionWithCudaProvider( 0 )
@@ -199,10 +230,21 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
199230 sessionOptions.GraphOptimizationLevel <- GraphOptimizationLevel.ORT_ ENABLE_ BASIC
200231
201232 let session = new InferenceSession( pathToONNX, sessionOptions)
233+
202234 let runOptions = new RunOptions()
203235 let feedback ( x : Feedback ) = ()
204236
205- let predict ( gameState : GameState ) =
237+ let mutable stepsPlayed = 0
238+ let mutable currentGameState = None
239+
240+ let predict ( gameStateOrDelta : GameState ) =
241+ let _ =
242+ match aiAgentTrainingModelOptions with
243+ | Some _ when not ( stepsPlayed = 0 ) ->
244+ currentGameState <- GameUtils.updateGameState gameStateOrDelta currentGameState
245+ | _ -> currentGameState <- Some gameStateOrDelta
246+
247+ let gameState = currentGameState.Value
206248 let stateIds = Dictionary< uint< stateId>, int>()
207249 let verticesIds = Dictionary< uint< basicBlockGlobalId>, int>()
208250
@@ -243,7 +285,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
243285 let j = i * numOfStateAttributes
244286 attributes.[ j] <- float32 v.Position
245287 // TODO: Support path condition
246- // attributes.[j + 1] <- float32 v.PathConditionSize
288+ // attributes.[j + 1] <- float32 v.PathConditionSize
247289 attributes.[ j + 2 ] <- float32 v.VisitedAgainVertices
248290 attributes.[ j + 3 ] <- float32 v.VisitedNotCoveredVerticesInZone
249291 attributes.[ j + 4 ] <- float32 v.VisitedNotCoveredVerticesOutOfZone
@@ -350,14 +392,30 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
350392 res
351393
352394 let output = session.Run( runOptions, networkInput, session.OutputNames)
395+
396+ let _ =
397+ match aiAgentTrainingModelOptions with
398+ | Some aiAgentOptions ->
399+ aiAgentOptions.stepSaver (
400+ AIGameStep( gameState = gameStateOrDelta, output = GameUtils.convertOutputToJson output)
401+ )
402+ | None -> ()
403+
404+ stepsPlayed <- stepsPlayed + 1
405+
353406 let weighedStates = output[ 0 ]. GetTensorDataAsSpan< float32>() .ToArray()
354407
355408 let id = weighedStates |> Array.mapi ( fun i v -> i, v) |> Array.maxBy snd |> fst
356409 stateIds |> Seq.find ( fun kvp -> kvp.Value = id) |> ( fun x -> x.Key)
357410
358411 Oracle( predict, feedback)
359412
360- AISearcher( createOracle pathToONNX, None)
413+ let aiAgentTrainingOptions =
414+ match aiAgentTrainingModelOptions with
415+ | Some aiAgentTrainingModelOptions -> Some( SendModel aiAgentTrainingModelOptions)
416+ | None -> None
417+
418+ AISearcher( createOracleRunner ( pathToONNX, aiAgentTrainingModelOptions), aiAgentTrainingOptions)
361419
362420 interface IForwardSearcher with
363421 override x.Init states = init states
0 commit comments