@@ -21,10 +21,10 @@ use tokio::time::sleep_until;
2121use tokio:: time:: Instant ;
2222use uuid:: Uuid ;
2323
24- mod ip_version ;
24+ mod protocols ;
2525
26- use self :: ip_version :: IpVersion ;
27- use self :: ip_version :: IP_VERSIONS ;
26+ use self :: protocols :: Protocol ;
27+ pub use self :: protocols :: Protocols ;
2828
2929const INTERVAL_HEARTBEAT : Duration = Duration :: from_secs ( 15 ) ;
3030const INTERVAL_INFO_CHANGE : Duration = Duration :: from_secs ( 1 ) ;
@@ -137,7 +137,7 @@ struct RegisterTaskShared {
137137 // If you want to have both the `RegisterShared` and this lock, take the
138138 // `RegisterShared` lock first, to avoid deadlocks.
139139 data : Mutex < RegisterTaskData > ,
140- ip_version : IpVersion ,
140+ protocol : Protocol ,
141141 next_register_changed : Notify ,
142142}
143143
@@ -188,6 +188,7 @@ impl RegisterTaskData {
188188pub struct RegisterBuilder {
189189 require_external_heartbeats : bool ,
190190 register_url : Option < String > ,
191+ protocols : Option < Protocols > ,
191192 user_agent : Option < String > ,
192193 community_token : Option < String > ,
193194}
@@ -203,6 +204,11 @@ impl RegisterBuilder {
203204 self . register_url = Some ( register_url) ;
204205 self
205206 }
207+ pub fn protocols ( mut self , protocols : Protocols ) -> RegisterBuilder {
208+ assert ! ( self . protocols. is_none( ) ) ;
209+ self . protocols = Some ( protocols) ;
210+ self
211+ }
206212 pub fn user_agent ( mut self , user_agent : String ) -> RegisterBuilder {
207213 assert ! ( self . user_agent. is_none( ) ) ;
208214 self . user_agent = Some ( user_agent) ;
@@ -220,18 +226,18 @@ impl RegisterBuilder {
220226
221227pub struct Register {
222228 shared : Arc < RegisterShared > ,
223- tasks : [ Arc < RegisterTaskShared > ; 2 ] ,
229+ tasks : [ Option < Arc < RegisterTaskShared > > ; 2 ] ,
224230}
225231
226232async fn register_task ( shared : Arc < RegisterShared > , task : Arc < RegisterTaskShared > ) -> ! {
227233 let client = reqwest:: Client :: builder ( )
228234 . user_agent ( & * shared. user_agent )
229- . local_address ( task. ip_version . bind_all ( ) )
235+ . local_address ( task. protocol . bind_all_addr ( ) )
230236 . build ( )
231237 . unwrap ( ) ;
232238
233239 let challenge_secret: Box < str > =
234- format ! ( "{}:{}" , shared. challenge_secret, task. ip_version ) . into ( ) ;
240+ format ! ( "{}:{}" , shared. challenge_secret, task. protocol ) . into ( ) ;
235241
236242 loop {
237243 // send register
@@ -334,6 +340,7 @@ impl Register {
334340 RegisterBuilder {
335341 require_external_heartbeats,
336342 register_url,
343+ protocols,
337344 user_agent,
338345 community_token,
339346 } : RegisterBuilder ,
@@ -345,6 +352,7 @@ impl Register {
345352 } else {
346353 Some ( INTERVAL_HEARTBEAT )
347354 } ;
355+ let protocols = protocols. unwrap_or ( Protocols :: all ( ) ) ;
348356
349357 let challenge_secret = Uuid :: new_v4 ( ) . to_string ( ) ;
350358 let challenge_packet_prefix: Vec < u8 > = [
@@ -380,7 +388,10 @@ impl Register {
380388 } ) ;
381389
382390 let now = Instant :: now ( ) ;
383- let tasks = IP_VERSIONS . map ( |ip_version| {
391+ let tasks = protocols:: ALL . map ( |protocol| {
392+ if !protocols. contains ( protocol) {
393+ return None ;
394+ }
384395 let task = Arc :: new ( RegisterTaskShared {
385396 data : Mutex :: new ( RegisterTaskData {
386397 token : None ,
@@ -390,11 +401,11 @@ impl Register {
390401 prev_register : now,
391402 next_register : period. map ( |p| now + p) ,
392403 } ) ,
393- ip_version ,
404+ protocol ,
394405 next_register_changed : Notify :: new ( ) ,
395406 } ) ;
396407 let _ = tokio:: spawn ( register_task ( shared. clone ( ) , task. clone ( ) ) ) ;
397- task
408+ Some ( task)
398409 } ) ;
399410
400411 Register { shared, tasks }
@@ -408,21 +419,37 @@ impl Register {
408419 data. info = info;
409420
410421 // Lock all the task data once.
411- let mut task_data: Vec < _ > = self . tasks . iter ( ) . map ( |t| t. data . lock ( ) . unwrap ( ) ) . collect ( ) ;
422+ let mut task_data: Vec < _ > = self
423+ . tasks
424+ . iter ( )
425+ . enumerate ( )
426+ . filter_map ( |( i, t) | t. as_ref ( ) . map ( |t| ( i, t. data . lock ( ) . unwrap ( ) ) ) )
427+ . collect ( ) ;
428+
429+ if task_data. is_empty ( ) {
430+ return ;
431+ }
432+
412433 // Expedite the next register that is closest to execution, but don't
413434 // move it closer than `INTERVAL_INFO_CHANGE` from the previous
414435 // register.
415436 let minimum_next_register_idx = task_data
416437 . iter ( )
417- . enumerate ( )
418- . filter_map ( |( i, d) | d. next_register . map ( |n| ( i, n) ) )
438+ . filter_map ( |& ( i, ref d) | d. next_register . map ( |n| ( i, n) ) )
419439 . min_by_key ( |& ( _, n) | n)
420440 . map ( |( i, _) | i)
421- . unwrap_or ( 0 ) ;
422- let maximum_prev_register = task_data. iter ( ) . map ( |d| d. prev_register ) . max ( ) . unwrap ( ) ;
423- task_data[ minimum_next_register_idx] . set_next_register (
441+ . unwrap_or ( task_data. first ( ) . unwrap ( ) . 0 ) ;
442+ let maximum_prev_register = task_data
443+ . iter ( )
444+ . map ( |( _, d) | d. prev_register )
445+ . max ( )
446+ . unwrap ( ) ;
447+ task_data[ minimum_next_register_idx] . 1 . set_next_register (
424448 maximum_prev_register + INTERVAL_INFO_CHANGE ,
425- & self . tasks [ minimum_next_register_idx] . next_register_changed ,
449+ & self . tasks [ minimum_next_register_idx]
450+ . as_ref ( )
451+ . unwrap ( )
452+ . next_register_changed ,
426453 ) ;
427454 }
428455 pub fn on_udp_packet ( & self , data : & [ u8 ] ) {
@@ -433,34 +460,37 @@ impl Register {
433460 . read_string ( )
434461 . ok ( )
435462 . and_then ( |s| str:: from_utf8 ( s) . ok ( ) )
436- . and_then ( |s| IpVersion :: from_str ( s) . ok ( ) ) ,
463+ . and_then ( |s| Protocol :: from_str ( s) . ok ( ) ) ,
437464 unpacker
438465 . read_string ( )
439466 . ok ( )
440467 . and_then ( |s| str:: from_utf8 ( s) . ok ( ) ) ,
441468 ) {
442- ( Some ( ip_version ) , Some ( token) ) => self . on_token ( ip_version , token) ,
469+ ( Some ( protocol ) , Some ( token) ) => self . on_token ( protocol , token) ,
443470 _ => error ! ( "invalid challenge packet from mastersrv" ) ,
444471 }
445472 }
446473 }
447- fn on_token ( & self , ip_version : IpVersion , token : & str ) {
448- debug ! ( "{ip_version} challenge_token={token:?}" ) ;
449- let task = & self . tasks [ ip_version. index ( ) ] ;
450- let mut task_data = task. data . lock ( ) . unwrap ( ) ;
451- if Some ( token) != task_data. token . as_deref ( ) {
452- task_data. token = Some ( String :: from ( token) . into_boxed_str ( ) . into ( ) ) ;
453- if let Some ( RegisterResult :: NeedChallenge ) = task_data. prev_result {
454- task_data. set_wait_time ( INTERVAL_TOKEN_REQUIRED , & task. next_register_changed ) ;
474+ fn on_token ( & self , protocol : Protocol , token : & str ) {
475+ debug ! ( "{protocol} challenge_token={token:?}" ) ;
476+ if let Some ( task) = & self . tasks [ protocol. index ( ) ] {
477+ let mut task_data = task. data . lock ( ) . unwrap ( ) ;
478+ if Some ( token) != task_data. token . as_deref ( ) {
479+ task_data. token = Some ( String :: from ( token) . into_boxed_str ( ) . into ( ) ) ;
480+ if let Some ( RegisterResult :: NeedChallenge ) = task_data. prev_result {
481+ task_data. set_wait_time ( INTERVAL_TOKEN_REQUIRED , & task. next_register_changed ) ;
482+ }
455483 }
456484 }
457485 }
458486 pub fn on_heartbeat ( & self ) {
459487 for task in & self . tasks {
460- task. data
461- . lock ( )
462- . unwrap ( )
463- . set_wait_time ( INTERVAL_HEARTBEAT , & task. next_register_changed ) ;
488+ if let Some ( task) = task {
489+ task. data
490+ . lock ( )
491+ . unwrap ( )
492+ . set_wait_time ( INTERVAL_HEARTBEAT , & task. next_register_changed ) ;
493+ }
464494 }
465495 }
466496}
0 commit comments