class ContextBandit:
    """Contextual multi-armed bandit with as many states (contexts) as arms.

    ``bandit_matrix[state][arm]`` is the per-trial success probability used
    to sample the reward for pulling ``arm`` while the environment is in
    ``state``. After every pull the state is redrawn uniformly at random.
    """

    def __init__(self, arms=10):
        """Create a bandit with ``arms`` arms and ``arms`` states.

        Args:
            arms: Number of arms. Also the number of states and the number of
                Bernoulli trials used to sample each reward.
        """
        self.arms = arms
        self.init_distribution(arms)
        self.update_state()

    def init_distribution(self, arms):
        """Draw a fresh ``(arms, arms)`` table of per-(state, arm) probabilities.

        Rows index states, columns index arms.
        """
        # reward() compares each entry against random.random(), so the entries
        # must be probabilities in [0, 1). A uniform draw guarantees that; a
        # standard-normal draw (randn) yields negatives (never rewarded) and
        # values >= 1 (always rewarded), which breaks the bandit.
        self.bandit_matrix = np.random.rand(arms, arms)

    def reward(self, prob):
        """Return the number of successes in ``self.arms`` Bernoulli(prob) trials.

        Args:
            prob: Success probability of each individual trial.

        Returns:
            An int in ``[0, self.arms]``.
        """
        return sum(1 for _ in range(self.arms) if random.random() < prob)

    def get_state(self):
        """Return the current state index."""
        return self.state

    def update_state(self):
        """Move to a new state drawn uniformly from ``[0, self.arms)``."""
        self.state = np.random.randint(0, self.arms)

    def get_reward(self, arm):
        """Sample a reward for ``arm`` in the current state (state is unchanged)."""
        return self.reward(self.bandit_matrix[self.get_state()][arm])

    def choose_arm(self, arm):
        """Pull ``arm``: sample its reward in the current state, then redraw the state.

        Args:
            arm: Index of the arm to pull.

        Returns:
            The sampled reward, an int in ``[0, self.arms]``.
        """
        reward = self.get_reward(arm)
        self.update_state()
        return reward
출력층의 활성화 함수는 ReLU다.