Skip to content

Commit f5d824e

Browse files
committed
try improving train.rs experiment
1 parent 4314231 commit f5d824e

1 file changed

Lines changed: 20 additions & 13 deletions

File tree

examples/acrobot-qtable/src/bin/train.rs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ fn get_reward(task: &AcrobotBalanceTask, state: &AcrobotState, action: &AcrobotA
1111

1212
let pend_pos_center = task.n_pendulum_digitization as f64 / 2.0;
1313
let arm_pos_center = task.n_arm_digitization as f64 / 2.0;
14-
let pend_vel_center = task.n_pendulum_digitization as f64 / 2.0;
14+
let pend_vel_center = task.n_pendulum_digitization as f64 / 2.0;
1515
let arm_vel_center = task.n_arm_digitization as f64 / 2.0;
1616

17-
if (state.n_pendulum_rad as f64 - pend_pos_center).abs() < 1.0 &&
18-
(state.n_arm_rad as f64 - arm_pos_center).abs() < 1.0 &&
19-
(state.n_pendulum_vel as f64 - pend_vel_center).abs() < 1.0 &&
20-
(state.n_arm_vel as f64 - arm_vel_center).abs() < 1.0
17+
if (state.n_pendulum_rad as f64 - pend_pos_center).abs() < 1.0
18+
&& (state.n_arm_rad as f64 - arm_pos_center).abs() < 1.0
19+
&& (state.n_pendulum_vel as f64 - pend_vel_center).abs() < 1.0
20+
&& (state.n_arm_vel as f64 - arm_vel_center).abs() < 1.0
2121
{
2222
return 500.0;
2323
}
@@ -68,14 +68,14 @@ fn main() {
6868
n_pendulum_digitization: usize = @"N_PENDULUM_DIGITIZATION" || 16;
6969
max_episodes: usize = @"MAX_EPISODES" || 1000000;
7070
episode_length: usize = @"EPISODE_LENGTH" || 5000;
71-
model_log_interval: usize = @"MODEL_LOG_INTERVAL" || 2000;
71+
model_save_interval: usize = @"MODEL_LOG_INTERVAL" || 1000;
7272
model_restore_file: std::path::PathBuf = @"MODEL_RESTORE_FILE";
73-
model_log_directory: std::path::PathBuf = @"MODEL_LOG_DIRECTORY" || std::env::current_dir().unwrap()
73+
model_save_directory: std::path::PathBuf = @"MODEL_LOG_DIRECTORY" || std::env::current_dir().unwrap()
7474
.join("models")
7575
.join(chrono::Local::now().format("%m-%d-%H-%M-%S").to_string());
7676
}
7777

78-
std::fs::create_dir_all(&model_log_directory).expect("Failed to create model log directory");
78+
std::fs::create_dir_all(&model_save_directory).expect("Failed to create model log directory");
7979

8080
let mut env = oxide_control::Environment::new(
8181
Acrobot::new(),
@@ -130,7 +130,7 @@ fn main() {
130130
// main training loop
131131
for episode in 1..=max_episodes {
132132
/* episode */
133-
let start_time = std::time::Instant::now();
133+
let start_time = env.physics().time(); //std::time::Instant::now();
134134
let mut episode_reward = 0.0;
135135
let mut obs = env.reset();
136136

@@ -157,11 +157,18 @@ fn main() {
157157
}
158158
}
159159

160-
if episode > 0 && episode % model_log_interval == 0 {
161-
println!("[episode {episode}]: return: {:.2}, time: {:?}", episode_reward, start_time.elapsed());
160+
if episode % 100 == 0 {
161+
println!(
162+
"[episode {episode}]: return: {:.2}, time: {:.2}",
163+
episode_reward,
164+
env.physics().time() - start_time
165+
);
166+
}
167+
if episode % model_save_interval == 0 {
168+
println!("[epidoe {episode}]: saveing agent as file");
162169
agent
163-
.save(model_log_directory.join(format!("agent-{episode}.json")))
164-
.expect("Failed to save Q-table");
170+
.save(model_save_directory.join(format!("agent-{episode}.json")))
171+
.expect("Failed to save agent");
165172
}
166173

167174
agent.decay_alpha_with_rate(0.9999);

0 commit comments

Comments
 (0)