@@ -11,13 +11,13 @@ fn get_reward(task: &AcrobotBalanceTask, state: &AcrobotState, action: &AcrobotA
1111
1212 let pend_pos_center = task. n_pendulum_digitization as f64 / 2.0 ;
1313 let arm_pos_center = task. n_arm_digitization as f64 / 2.0 ;
14- let pend_vel_center = task. n_pendulum_digitization as f64 / 2.0 ;
14+ let pend_vel_center = task. n_pendulum_digitization as f64 / 2.0 ;
1515 let arm_vel_center = task. n_arm_digitization as f64 / 2.0 ;
1616
17- if ( state. n_pendulum_rad as f64 - pend_pos_center) . abs ( ) < 1.0 &&
18- ( state. n_arm_rad as f64 - arm_pos_center) . abs ( ) < 1.0 &&
19- ( state. n_pendulum_vel as f64 - pend_vel_center) . abs ( ) < 1.0 &&
20- ( state. n_arm_vel as f64 - arm_vel_center) . abs ( ) < 1.0
17+ if ( state. n_pendulum_rad as f64 - pend_pos_center) . abs ( ) < 1.0
18+ && ( state. n_arm_rad as f64 - arm_pos_center) . abs ( ) < 1.0
19+ && ( state. n_pendulum_vel as f64 - pend_vel_center) . abs ( ) < 1.0
20+ && ( state. n_arm_vel as f64 - arm_vel_center) . abs ( ) < 1.0
2121 {
2222 return 500.0 ;
2323 }
@@ -68,14 +68,14 @@ fn main() {
6868 n_pendulum_digitization: usize = @"N_PENDULUM_DIGITIZATION" || 16 ;
6969 max_episodes: usize = @"MAX_EPISODES" || 1000000 ;
7070 episode_length: usize = @"EPISODE_LENGTH" || 5000 ;
71- model_log_interval : usize = @"MODEL_LOG_INTERVAL" || 2000 ;
71+ model_save_interval : usize = @"MODEL_LOG_INTERVAL" || 1000 ;
7272 model_restore_file: std:: path:: PathBuf = @"MODEL_RESTORE_FILE" ;
73- model_log_directory : std:: path:: PathBuf = @"MODEL_LOG_DIRECTORY" || std:: env:: current_dir( ) . unwrap( )
73+ model_save_directory : std:: path:: PathBuf = @"MODEL_LOG_DIRECTORY" || std:: env:: current_dir( ) . unwrap( )
7474 . join( "models" )
7575 . join( chrono:: Local :: now( ) . format( "%m-%d-%H-%M-%S" ) . to_string( ) ) ;
7676 }
7777
78- std:: fs:: create_dir_all ( & model_log_directory ) . expect ( "Failed to create model log directory" ) ;
78+ std:: fs:: create_dir_all ( & model_save_directory ) . expect ( "Failed to create model save directory" ) ;
7979
8080 let mut env = oxide_control:: Environment :: new (
8181 Acrobot :: new ( ) ,
@@ -130,7 +130,7 @@ fn main() {
130130 // main training loop
131131 for episode in 1 ..=max_episodes {
132132 /* episode */
133- let start_time = std:: time:: Instant :: now ( ) ;
133+ let start_time = env . physics ( ) . time ( ) ; // std::time::Instant::now();
134134 let mut episode_reward = 0.0 ;
135135 let mut obs = env. reset ( ) ;
136136
@@ -157,11 +157,18 @@ fn main() {
157157 }
158158 }
159159
160- if episode > 0 && episode % model_log_interval == 0 {
161- println ! ( "[episode {episode}]: return: {:.2}, time: {:?}" , episode_reward, start_time. elapsed( ) ) ;
160+ if episode % 100 == 0 {
161+ println ! (
162+ "[episode {episode}]: return: {:.2}, time: {:.2}" ,
163+ episode_reward,
164+ env. physics( ) . time( ) - start_time
165+ ) ;
166+ }
167+ if episode % model_save_interval == 0 {
168+ println ! ( "[episode {episode}]: saving agent as file" ) ;
162169 agent
163- . save ( model_log_directory . join ( format ! ( "agent-{episode}.json" ) ) )
164- . expect ( "Failed to save Q-table " ) ;
170+ . save ( model_save_directory . join ( format ! ( "agent-{episode}.json" ) ) )
171+ . expect ( "Failed to save agent" ) ;
165172 }
166173
167174 agent. decay_alpha_with_rate ( 0.9999 ) ;
0 commit comments