forked from foersterrobert/AlphaZeroFromScratch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmcts.py
More file actions
70 lines (48 loc) · 2.06 KB
/
mcts.py
File metadata and controls
70 lines (48 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
import torch
from node import Node
class MCTS:
    """Monte Carlo Tree Search guided by a policy/value network (AlphaZero style).

    The model is called with a batch of one encoded state and must return a
    ``(policy_logits, value)`` pair: the logits are soft-maxed over dim 1 and
    the value is read with ``.item()``.

    ``args`` must provide the keys ``'num_searches'``, ``'dirichlet_alpha'``
    and ``'dirichlet_epsilon'``; it is also passed through to every ``Node``.
    """

    def __init__(self, game, args, model):
        self.game = game
        self.args = args
        self.model = model
        # Not read anywhere in this class; presumably used by external code
        # (e.g. statistics/debugging) -- TODO confirm before removing.
        self.node_depth_counts = {}

    @staticmethod
    def _normalize_policy(policy, valid_moves):
        """Zero out illegal actions in *policy* and renormalise it to sum to 1.

        If the legal actions carry no probability mass at all (e.g. float32
        softmax underflow), fall back to a uniform distribution over the valid
        moves instead of dividing by zero and producing NaNs. If there are no
        valid moves either, the all-zero masked vector is returned unchanged.

        Args:
            policy: 1-D array of action probabilities.
            valid_moves: array of the same length; non-zero marks a legal action.

        Returns:
            A new 1-D numpy array of normalised probabilities.
        """
        masked = policy * valid_moves
        total = np.sum(masked)
        if total == 0:
            valid = np.asarray(valid_moves, dtype=np.float64)
            n_valid = np.sum(valid)
            if n_valid == 0:
                return masked
            return valid / n_valid
        return masked / total

    @torch.no_grad()
    def _evaluate(self, state):
        """Run the model on a single *state*.

        Returns:
            ``(policy, value)`` where ``policy`` is the softmax of the policy
            logits as a 1-D numpy array and ``value`` is the raw value output
            of the model (a tensor; callers use ``.item()`` when needed).
        """
        encoded = torch.tensor(
            self.game.get_encoded_state(state), device=self.model.device
        ).unsqueeze(0)  # add a batch dimension of size 1
        policy_logits, value = self.model(encoded)
        policy = torch.softmax(policy_logits, dim=1).squeeze(0).cpu().numpy()
        return policy, value

    @torch.no_grad()
    def search(self, state):
        """Run ``args['num_searches']`` MCTS simulations from *state*.

        Args:
            state: the game state to search from (format defined by ``game``).

        Returns:
            A numpy array of length ``game.action_size`` holding the visit-count
            distribution over the root's children (sums to 1). If no child was
            visited (e.g. ``num_searches == 0``), the root prior is returned
            instead of an all-NaN array.
        """
        # Root starts with visit_count=1, presumably so the exploration term of
        # Node.select is non-zero on the first descent -- confirm against Node.
        root = Node(self.game, self.args, state, visit_count=1)

        # Root prior: network policy mixed with Dirichlet noise for exploration,
        # then restricted to the legal moves.
        policy, _ = self._evaluate(state)
        epsilon = self.args['dirichlet_epsilon']
        noise = np.random.dirichlet(
            [self.args['dirichlet_alpha']] * self.game.action_size
        )
        policy = (1 - epsilon) * policy + epsilon * noise
        root_policy = self._normalize_policy(policy, self.game.get_valid_moves(state))
        root.expand(root_policy)

        for _ in range(self.args['num_searches']):
            node = root

            # Selection: descend while the current node reports itself fully expanded.
            while node.is_fully_expanded():
                node = node.select()

            value, is_terminal = self.game.get_value_and_terminated(
                node.state, node.action_taken
            )
            # Convert to the other player's perspective before backpropagation
            # (presumably the game scores from the side that played action_taken).
            value = self.game.get_opponent_value(value)

            if not is_terminal:
                # Evaluation + expansion: the network replaces the value with its
                # own estimate and provides the priors for the new children.
                policy, value = self._evaluate(node.state)
                policy = self._normalize_policy(
                    policy, self.game.get_valid_moves(node.state)
                )
                value = value.item()
                node.expand(policy)

            # Backpropagation (for both terminal and freshly expanded leaves).
            node.backpropagate(value)

        action_probs = np.zeros(self.game.action_size)
        for child in root.children:
            action_probs[child.action_taken] = child.visit_count
        total_visits = np.sum(action_probs)
        if total_visits == 0:
            # No simulation reached a child: return the root prior rather than NaNs.
            return np.array(root_policy, dtype=np.float64)
        return action_probs / total_visits