import copy
from enum import IntEnum
from typing import Optional, Tuple
import torch
from torch import nn
from axelrod.action import Action
from axelrod.load_data_ import load_attention_model_weights
from axelrod.player import Player
C, D = Action.C, Action.D
MEMORY_LENGTH = 200
CLS_TOKEN = 0
PAD_TOKEN = 1
DEVICES = torch.device("cpu")
[docs]
class GameState(IntEnum):
CooperateDefect = 2
DefectCooperate = 3
CooperateCooperate = 4
DefectDefect = 5
def actions_to_game_state(
player_action: Action, opponent_action: Action
) -> GameState:
action_mapping = {
(C, D): GameState.CooperateDefect,
(D, C): GameState.DefectCooperate,
(C, C): GameState.CooperateCooperate,
(D, D): GameState.DefectDefect,
}
return action_mapping[(player_action, opponent_action)]
def compute_features(
player: Player, opponent: Player, right_pad: bool = False
) -> torch.IntTensor:
# The first token is the CLS token
player_history = player.history[-MEMORY_LENGTH:]
player_history = player_history[::-1]
opponent_history = opponent.history[-MEMORY_LENGTH:]
opponent_history = opponent_history[::-1]
feature_size = MEMORY_LENGTH + 1 if right_pad else len(player_history) + 1
game_history = torch.full((feature_size,), PAD_TOKEN, dtype=torch.int)
game_history[0] = CLS_TOKEN
for index, (action_player, action_opponent) in enumerate(
zip(player_history, opponent_history)
):
game_state = actions_to_game_state(action_player, action_opponent)
game_history[index + 1] = game_state
return game_history
[docs]
class GELUActivation(nn.Module):
def __init__(self):
super().__init__()
[docs]
def forward(self, input):
return nn.functional.gelu(input)
class PlayerConfig:
def __init__(
self,
state_size=6, # Number of possible game states, 4 possible game states and 2 specials token
hidden_size=256,
num_hidden_layers=24,
num_attention_heads=8,
intermediate_size=512,
hidden_dropout_prob=0.3,
attention_probs_dropout_prob=0.3,
max_game_size=MEMORY_LENGTH + 1, # Add 1 for the CLS token
initializer_range=0.02,
layer_norm_eps=1e-12,
):
self.state_size = state_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_game_size = max_game_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
[docs]
class PlayerEmbeddings(nn.Module):
"""Construct the embeddings from game state and position embeddings."""
def __init__(self, config: PlayerConfig):
super().__init__()
self.game_state_embeddings = nn.Embedding(
config.state_size, config.hidden_size
)
self.position_embeddings = nn.Embedding(
config.max_game_size, config.hidden_size
)
self.LayerNorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.register_buffer(
"position_ids",
torch.arange(config.max_game_size).expand((1, -1)),
persistent=False,
)
[docs]
def forward(
self,
input_ids: torch.LongTensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
input_shape = input_ids.size()
seq_length = input_shape[1]
position_ids = self.position_ids[:, 0:seq_length]
embeddings = self.game_state_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings += position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
attention_mask = (input_ids != PAD_TOKEN).long()
return embeddings, attention_mask
[docs]
class PlayerSelfAttention(nn.Module):
def __init__(self, config: PlayerConfig):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(
config.hidden_size / config.num_attention_heads
)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout_prob = config.attention_probs_dropout_prob
def _transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = src_len
expanded_mask = (
mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
[docs]
def forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
bsz, tgt_len, _ = hidden_states.size()
query_layer = self._transpose_for_scores(self.query(hidden_states))
key_layer = self._transpose_for_scores(self.key(hidden_states))
value_layer = self._transpose_for_scores(self.value(hidden_states))
attn_mask = self._expand_mask(attention_mask, query_layer.dtype)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
dropout_p=self.dropout_prob if self.training else 0.0,
attn_mask=attn_mask,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
return attn_output
[docs]
class PlayerSelfOutput(nn.Module):
def __init__(self, config: PlayerConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
[docs]
def forward(
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
[docs]
class PlayerAttention(nn.Module):
def __init__(self, config: PlayerConfig):
super().__init__()
self.self = PlayerSelfAttention(config)
self.output = PlayerSelfOutput(config)
[docs]
def forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
self_outputs = self.self(hidden_states, attention_mask)
attention_output = self.output(self_outputs, hidden_states)
return attention_output
[docs]
class PlayerOutput(nn.Module):
def __init__(self, config: PlayerConfig):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
[docs]
def forward(
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
[docs]
class PlayerLayer(nn.Module):
def __init__(self, config: PlayerConfig):
super().__init__()
self.seq_len_dim = 1
self.attention = PlayerAttention(config)
self.intermediate = PlayerIntermediate(config)
self.output = PlayerOutput(config)
[docs]
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
[docs]
class PlayerEncoder(nn.Module):
def __init__(self, config: PlayerConfig):
super().__init__()
self.layer = nn.ModuleList(
[PlayerLayer(config) for _ in range(config.num_hidden_layers)]
)
[docs]
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask)
return hidden_states
[docs]
class PlayerPooler(nn.Module):
def __init__(self, config: PlayerConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
[docs]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
[docs]
class PlayerModel(nn.Module):
_no_split_modules = ["PlayerEmbeddings"]
def __init__(self, config: PlayerConfig):
super().__init__()
self.config = config
self.embeddings = PlayerEmbeddings(config)
self.encoder = PlayerEncoder(config)
self.pooler = PlayerPooler(config)
self.action = nn.Linear(config.hidden_size, 1)
[docs]
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding_output, attention_mask = self.embeddings(input_ids=input_ids)
sequence_output = self.encoder(embedding_output, attention_mask)
pooled_output = self.pooler(sequence_output)
return self.action(pooled_output)
def __eq__(self, other: object) -> bool:
return isinstance(other, PlayerModel)
[docs]
class EvolvedAttention(Player):
"""A player who uses an attention mechanism to analyse the game. Trained with self-play.
Names:
- EvolvedAttention: EvolvedAttention by Marc-Olivier Derouin
"""
name = "EvolvedAttention"
classifier = {
"memory_depth": MEMORY_LENGTH,
"stochastic": False,
"long_run_time": True,
"inspects_source": False,
"manipulates_source": False,
"manipulates_state": False,
}
def __init__(
self,
) -> None:
super().__init__()
self.model: Optional[PlayerModel] = None
[docs]
def load_model(self) -> None:
"""Load the model weights."""
if self.model is None:
self.model = PlayerModel(PlayerConfig())
self.model.load_state_dict(load_attention_model_weights())
self.model.to(DEVICES)
self.model.eval()
[docs]
def strategy(self, opponent: Player) -> Action:
"""Actual strategy definition that determines player's action."""
# Load the model if not already loaded
self.load_model()
assert self.model is not None, "Model must be loaded before playing."
# Compute features
features = compute_features(self, opponent).unsqueeze(0).to(DEVICES)
# Get action from the model
logits = self.model(features)
# Apply sigmoid
logits = torch.sigmoid(logits)
return C if logits.item() < 0.5 else D