给newbing debug看的
下面是dqn_wrappers的代码:
"""
Adapted from OpenAI Baselines
https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
"""
from collections import deque
import numpy as np
import gym
import copy
import CV2
CV2.ocl.setUseOpenCL(False)
def make_env(env, stack_frames=True, episodic_life=True, clip_rewards=False, scale=False):
if episodic_life:
env = EpisodicLifeEnv(env)
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env)
if stack_frames:
env = FrameStack(env, 4)
if clip_rewards:
env = ClipRewardEnv(env)
return env
class RewardScaler(gym.RewardWrapper):
def reward(self, reward):
return reward * 0.1
class ClipRewardEnv(gym.RewardWrapper):
def __init__(self, env):
gym.RewardWrapper.__init__(self, env)
def reward(self, reward):
"""Bin reward to {+1, 0, -1} by its sign."""
return np.sign(reward)
class LazyFrames(object):
def __init__(self, frames):
"""This object ensures that common frames between the observations are only stored once.
It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
buffers.
This object should only be converted to numpy array before being passed to the model.
You'd not believe how complex the previous solution was."""
self._frames = frames
self._out = None
def _force(self):
if self._out is None:
self._out = np.concatenate(self._frames, axis=2)
self._frames = None
return self._out
def __array__(self, dtype=None):
out = self._force()
if dtype is not None:
out = out.astype(dtype)
return out
def __len__(self):
return len(self._force())
def __getitem__(self, i):
return self._force()[i]
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Stack k last frames.
Returns lazy array, which is much more memory efficient.
See Also
--------
baselines.common.atari_wrappers.LazyFrames
"""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
def reset(self):
ob = self.env.reset()[0]
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob()
def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info
def _get_ob(self):
assert len(self.frames) == self.k
return LazyFrames(list(self.frames))
class WarpFrame(gym.ObservationWrapper):
def __init__(self, env):
"""Warp frames to 84x84 as done in the Nature paper and later work."""
gym.ObservationWrapper.__init__(self, env)
self.width = 84
self.height = 84
self.observation_space = gym.spaces.Box(low=0, high=255,
shape=(self.height, self.width, 1), dtype=np.uint8)
def observation(self, frame):
#print(frame.shape)
frame = frame.astype(np.uint8)
frame = CV2.cvtColor(frame, CV2.COLOR_GRAY2BGR)
#frame = CV2.cvtColor(frame, CV2.COLOR_RGB2GRAY)
frame = CV2.resize(frame, (self.width, self.height), interpolation=CV2.INTER_AREA)
return frame[:, :, None]
class FireResetEnv(gym.Wrapper):
def __init__(self, env=None):
"""For environments where the user need to press FIRE for the game to start."""
super(FireResetEnv, self).__init__(env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def step(self, action):
return self.env.step(action)
def reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset()
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset()
return obs
class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env=None):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
"""
super(EpisodicLifeEnv, self).__init__(env)
self.lives = 0
self.was_real_done = True
self.was_real_reset = False
def step(self, action):
obs, reward, done, info,_ = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert somtimes we stay in lives == 0 condtion for a few frames
# so its important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info
def reset(self):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset()
self.was_real_reset = True
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
self.was_real_reset = False
self.lives = self.env.unwrapped.ale.lives()
return obs
class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env=None, skip=4):
"""Return only every `skip`-th frame"""
super(MaxAndSkipEnv, self).__init__(env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = deque(maxlen=2)
self._skip = skip
def step(self, action):
total_reward = 0.0
done = None
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
self._obs_buffer.append(obs)
total_reward += reward
if done:
break
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
return max_frame, total_reward, done, info
def reset(self):
"""Clear past frame buffer and init. to first obs. from inner env."""
self._obs_buffer.clear()
obs = self.env.reset()
self._obs_buffer.append(obs)
return obs
class NoopResetEnv(gym.Wrapper):
def __init__(self, env=None, noop_max=30):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
"""
super(NoopResetEnv, self).__init__(env)
self.noop_max = noop_max
self.override_num_noops = None
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
def step(self, action):
return self.env.step(action)
def reset(self):
""" Do no-op action for a number of steps in [1, noop_max]."""
self.env.reset()
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = np.random.randint(1, self.noop_max + 1)
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(0)
if done:
obs = self.env.reset()
return obs
然后是正文
# %load fig4b_5ab
import gym
import torch
import numpy as np
import copy
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
from dqn_wrappers import *
def get_state(obs):
state = np.array(obs)
state = state.transpose((2,0,1))
state = torch.from_numpy(state)
return state.unsqueeze(0).float()
def select_action(state, policy, eps=0.0):
sample = random.random()
if sample > eps:
with torch.no_grad():
return policy(state.to(device)).max(1)[1].view(1,1).item()
else:
return env.action_space.sample()
def TD0(n_episodes, gamma, alpha0, eta, B, phi0, eps=0.0, cut_factor = 1):
CIs_Q = np.zeros([n_episodes,2])
CIs_SE = np.zeros([n_episodes,2])
Values = np.zeros(n_episodes)
Theta = np.zeros([B+1,p])
Thetabar = np.zeros([B+1,p])
cut = n_episodes // cut_factor
j=0
for i in tqdm(range(n_episodes)):
done = False
state = get_state(env.reset())
while not done:
j += 1
alpha_t = alpha0 * j**(-eta)
W = np.concatenate([[1], np.random.exponential(size=B)])
action = select_action(state, policy0, eps=eps)
obs_, reward, done, info = env.step(action)
state_ = get_state(obs_)
phi = model(state.to(device)).detach().to('cpu').numpy().squeeze()
phi_ = model(state_.to(device)).detach().to('cpu').numpy().squeeze()
At = np.outer(phi, (phi - gamma*phi_))
bt = reward*phi
Theta += alpha_t*np.diag(W) @ (bt[np.newaxis,:] - Theta@At.T)
if j > cut:
Thetabar = ((j-cut-1)*Thetabar + Theta)/(j-cut)
else:
Thetabar = Theta
state = state_
values = Thetabar @ phi0
value = values[0]
valuesW = values[1:]
Q = value + np.quantile(value - valuesW, [0.025, 0.975], axis=0)
SE = float(np.sqrt(np.cov(value - valuesW)))
SE = value + 1.96*SE*np.array([-1,1])
Values[i] = value
CIs_Q[i,:] = Q
CIs_SE[i,:] = SE
return {'value': Values, 'Q': CIs_Q, 'SE': CIs_SE}
if __name__ == "__main__":
device = torch.device("cpu")
policy0 = torch.load('dqn_pong_model', map_location=torch.device('cpu'))
env = gym.make("PongNoFrameskip-v4")
env.reset()
env = make_env(env)
nactions = env.action_space.n
model = copy.deepcopy(policy0)
model.head = torch.nn.Identity()
gamma = 0.99
p = 512
n_episodes = 100
alpha0 = 0.5
eta = 2/3
eps=0.0
state = get_state(env.reset())
phi0 = model(state.to(device)).detach().to('cpu').numpy().squeeze()
CIs = TD0(n_episodes, gamma,alpha0,eta,200,phi0)
values = CIs['value']
CIs_Q = CIs['Q']
CIs_SE = CIs['SE']
# fig 5b
start=10
plt.rcParams['figure.figsize'] = (16,9)
plt.rcParams['figure.facecolor'] = 'white'
plt.plot(np.arange(n_episodes)[start:], values[start:], color='black', label='value estimate', linewidth=3)
plt.plot(np.arange(n_episodes)[start:], CIs_Q[start:], color = 'blue', label='Q', linewidth=3)
plt.plot(np.arange(n_episodes)[start:], CIs_SE[start:], color = 'red', label='SE', linewidth=3)
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys(), fontsize='xx-large')
plt.xlabel('Number of episodes', fontsize='xx-large')
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.savefig('atari_CI_example.png', bbox_inches='tight')
plt.close()
# fig 6a
CIs_Q_widths = CIs_Q[:,1] - CIs_Q[:,0]
CIs_SE_widths = CIs_SE[:,1] - CIs_SE[:,0]
plt.rcParams['figure.figsize'] = (16,9)
plt.rcParams['figure.facecolor'] = 'white'
plt.plot(np.arange(n_episodes), CIs_Q_widths, color = 'blue', label='Q', linewidth=3)
plt.plot(np.arange(n_episodes), CIs_SE_widths, color = 'red', label='SE', linewidth=3)
plt.legend(fontsize='xx-large')
plt.xlabel('Number of episodes', fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.savefig('atari_CI_widths.png', bbox_inches='tight')
plt.close()
# fig 6b
eps_seq = np.linspace(0,1,5)
eps_dict = dict()
for eps in eps_seq:
eps_dict[eps] = TD0(n_episodes, gamma,alpha0,eta,200,phi0,eps)
SE = np.array([x['SE'] for x in eps_dict.values()])
Q = np.array([x['Q'] for x in eps_dict.values()])
values = np.array([x['value'] for x in eps_dict.values()])
plt.rcParams['figure.figsize'] = (16,9)
plt.rcParams['figure.facecolor'] = 'white'
plt.errorbar(eps_seq, values, yerr=Q.T, fmt='.k')
plt.xlabel('epsilon', fontsize=20)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.savefig('atari_CI_bars.png', bbox_inches='tight')
plt.close()
然后是报错信息:
---------------------------------------------------------------------------ValueError Traceback (most recent call last) Input In [2], in <cell line: 66>() 85 eta = 2/3 86 eps=0.0---> 89 state = get_state(env.reset()) 91 phi0 = model(state.to(device)).detach().to('cpu').numpy().squeeze() 93 CIs = TD0(n_episodes, gamma,alpha0,eta,200,phi0) Input In [2], in get_state(obs) 11 def get_state(obs): 12 state = np.array(obs)---> 13 state = state.transpose((2,0,1)) 14 state = torch.from_numpy(state) 15 return state.unsqueeze(0).float()ValueError: axes don't match array