
For NewBing to debug

2023-03-06 17:12 | Author: shieldermash

Below is the code of dqn_wrappers:

"""

Adapted from OpenAI Baselines

https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py

"""

from collections import deque

import numpy as np

import gym

import copy

import cv2

cv2.ocl.setUseOpenCL(False)


def make_env(env, stack_frames=True, episodic_life=True, clip_rewards=False, scale=False):

    if episodic_life:

        env = EpisodicLifeEnv(env)


    env = NoopResetEnv(env, noop_max=30)

    env = MaxAndSkipEnv(env, skip=4)

    if 'FIRE' in env.unwrapped.get_action_meanings():

        env = FireResetEnv(env)


    env = WarpFrame(env)

    if stack_frames:

        env = FrameStack(env, 4)

    if clip_rewards:

        env = ClipRewardEnv(env)

    return env


class RewardScaler(gym.RewardWrapper):


    def reward(self, reward):

        return reward * 0.1



class ClipRewardEnv(gym.RewardWrapper):

    def __init__(self, env):

        gym.RewardWrapper.__init__(self, env)


    def reward(self, reward):

        """Bin reward to {+1, 0, -1} by its sign."""

        return np.sign(reward)



class LazyFrames(object):

    def __init__(self, frames):

        """This object ensures that common frames between the observations are only stored once.

        It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay

        buffers.

        This object should only be converted to numpy array before being passed to the model.

        You'd not believe how complex the previous solution was."""

        self._frames = frames

        self._out = None


    def _force(self):

        if self._out is None:

            self._out = np.concatenate(self._frames, axis=2)

            self._frames = None

        return self._out


    def __array__(self, dtype=None):

        out = self._force()

        if dtype is not None:

            out = out.astype(dtype)

        return out


    def __len__(self):

        return len(self._force())


    def __getitem__(self, i):

        return self._force()[i]


class FrameStack(gym.Wrapper):

    def __init__(self, env, k):

        """Stack k last frames.

        Returns lazy array, which is much more memory efficient.

        See Also

        --------

        baselines.common.atari_wrappers.LazyFrames

        """

        gym.Wrapper.__init__(self, env)

        self.k = k

        self.frames = deque([], maxlen=k)

        shp = env.observation_space.shape

        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)


    def reset(self):

        ob = self.env.reset()[0]

        for _ in range(self.k):

            self.frames.append(ob)

        return self._get_ob()


    def step(self, action):

        ob, reward, done, info = self.env.step(action)

        self.frames.append(ob)

        return self._get_ob(), reward, done, info


    def _get_ob(self):

        assert len(self.frames) == self.k

        return LazyFrames(list(self.frames))



class WarpFrame(gym.ObservationWrapper):

    def __init__(self, env):

        """Warp frames to 84x84 as done in the Nature paper and later work."""

        gym.ObservationWrapper.__init__(self, env)

        self.width = 84

        self.height = 84

        self.observation_space = gym.spaces.Box(low=0, high=255,

            shape=(self.height, self.width, 1), dtype=np.uint8)


    def observation(self, frame):

        #print(frame.shape)

        frame = frame.astype(np.uint8)

        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)

        #frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)

        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)

        return frame[:, :, None]



class FireResetEnv(gym.Wrapper):

    def __init__(self, env=None):

        """For environments where the user need to press FIRE for the game to start."""

        super(FireResetEnv, self).__init__(env)

        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'

        assert len(env.unwrapped.get_action_meanings()) >= 3


    def step(self, action):

        return self.env.step(action)


    def reset(self):

        self.env.reset()

        obs, _, done, _ = self.env.step(1)

        if done:

            self.env.reset()

        obs, _, done, _ = self.env.step(2)

        if done:

            self.env.reset()

        return obs



class EpisodicLifeEnv(gym.Wrapper):

    def __init__(self, env=None):

        """Make end-of-life == end-of-episode, but only reset on true game over.

        Done by DeepMind for the DQN and co. since it helps value estimation.

        """

        super(EpisodicLifeEnv, self).__init__(env)

        self.lives = 0

        self.was_real_done = True

        self.was_real_reset = False


    def step(self, action):

        obs, reward, done, info,_ = self.env.step(action)

        self.was_real_done = done

        # check current lives, make loss of life terminal,

        # then update lives to handle bonus lives

        lives = self.env.unwrapped.ale.lives()

        if lives < self.lives and lives > 0:

            # for Qbert sometimes we stay in lives == 0 condition for a few frames

            # so it's important to keep lives > 0, so that we only reset once

            # the environment advertises done.

            done = True

        self.lives = lives

        return obs, reward, done, info


    def reset(self):

        """Reset only when lives are exhausted.

        This way all states are still reachable even though lives are episodic,

        and the learner need not know about any of this behind-the-scenes.

        """

        if self.was_real_done:

            obs = self.env.reset()

            self.was_real_reset = True

        else:

            # no-op step to advance from terminal/lost life state

            obs, _, _, _ = self.env.step(0)

            self.was_real_reset = False

        self.lives = self.env.unwrapped.ale.lives()

        return obs



class MaxAndSkipEnv(gym.Wrapper):

    def __init__(self, env=None, skip=4):

        """Return only every `skip`-th frame"""

        super(MaxAndSkipEnv, self).__init__(env)

        # most recent raw observations (for max pooling across time steps)

        self._obs_buffer = deque(maxlen=2)

        self._skip = skip


    def step(self, action):

        total_reward = 0.0

        done = None

        for _ in range(self._skip):

            obs, reward, done, info = self.env.step(action)

            self._obs_buffer.append(obs)

            total_reward += reward

            if done:

                break


        max_frame = np.max(np.stack(self._obs_buffer), axis=0)


        return max_frame, total_reward, done, info


    def reset(self):

        """Clear past frame buffer and init. to first obs. from inner env."""

        self._obs_buffer.clear()

        obs = self.env.reset()

        self._obs_buffer.append(obs)

        return obs


class NoopResetEnv(gym.Wrapper):

    def __init__(self, env=None, noop_max=30):

        """Sample initial states by taking random number of no-ops on reset.

        No-op is assumed to be action 0.

        """

        super(NoopResetEnv, self).__init__(env)

        self.noop_max = noop_max

        self.override_num_noops = None

        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'


    def step(self, action):

        return self.env.step(action)


    def reset(self):

        """ Do no-op action for a number of steps in [1, noop_max]."""

        self.env.reset()

        if self.override_num_noops is not None:

            noops = self.override_num_noops

        else:

            noops = np.random.randint(1, self.noop_max + 1)

        assert noops > 0

        obs = None

        for _ in range(noops):

            obs, _, done, _ = self.env.step(0)

            if done:

                obs = self.env.reset()

        return obs
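
For reference, a quick sanity check of the wrapper stack (this snippet is not part of dqn_wrappers; it assumes the file above is saved as dqn_wrappers.py): WarpFrame declares an 84x84x1 observation space and FrameStack stacks 4 frames along the last axis, so under the classic gym API, where reset() returns only the observation, np.array(env.reset()) should come out as (84, 84, 4).

import gym

import numpy as np

from dqn_wrappers import make_env  # assumes the module above was saved as dqn_wrappers.py

env = make_env(gym.make("PongNoFrameskip-v4"))

obs = env.reset()                      # a LazyFrames object under the classic gym API

print(type(obs), np.array(obs).shape)  # expected: (84, 84, 4)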

Next is the main script:

# %load fig4b_5ab

import gym

import torch

import numpy as np

import copy

import random

from tqdm import tqdm

import matplotlib.pyplot as plt

from dqn_wrappers import *


def get_state(obs):

    state = np.array(obs)

    state = state.transpose((2,0,1))

    state = torch.from_numpy(state)

    return state.unsqueeze(0).float()


def select_action(state, policy, eps=0.0):

    sample = random.random()

    if sample > eps:

        with torch.no_grad():

            return policy(state.to(device)).max(1)[1].view(1,1).item()

    else:

        return env.action_space.sample()


def TD0(n_episodes, gamma, alpha0, eta, B, phi0, eps=0.0, cut_factor = 1):

    CIs_Q = np.zeros([n_episodes,2])

    CIs_SE = np.zeros([n_episodes,2])

    Values = np.zeros(n_episodes)

    Theta = np.zeros([B+1,p])

    Thetabar = np.zeros([B+1,p])

    cut = n_episodes // cut_factor

    j=0


    for i in tqdm(range(n_episodes)):

        done = False

        state = get_state(env.reset())

        while not done:

            j += 1

            alpha_t = alpha0 * j**(-eta)

            W = np.concatenate([[1], np.random.exponential(size=B)])

            action = select_action(state, policy0, eps=eps)

            obs_, reward, done, info = env.step(action)

            state_ = get_state(obs_)

            phi = model(state.to(device)).detach().to('cpu').numpy().squeeze()

            phi_ = model(state_.to(device)).detach().to('cpu').numpy().squeeze()

            At = np.outer(phi, (phi - gamma*phi_))

            bt = reward*phi

            Theta += alpha_t*np.diag(W) @ (bt[np.newaxis,:] - Theta@At.T)

            if j > cut:

                Thetabar = ((j-cut-1)*Thetabar + Theta)/(j-cut)

            else:

                Thetabar = Theta

            state = state_

        values = Thetabar @ phi0

        value = values[0]

        valuesW = values[1:]

        Q = value + np.quantile(value - valuesW, [0.025, 0.975], axis=0)

        SE = float(np.sqrt(np.cov(value - valuesW)))

        SE = value + 1.96*SE*np.array([-1,1])

        Values[i] = value

        CIs_Q[i,:] = Q

        CIs_SE[i,:] = SE


    return {'value': Values, 'Q': CIs_Q, 'SE': CIs_SE}


if __name__ == "__main__":

    device = torch.device("cpu")


    policy0 = torch.load('dqn_pong_model', map_location=torch.device('cpu'))


    env = gym.make("PongNoFrameskip-v4")

    env.reset()

    env = make_env(env)


  

    nactions = env.action_space.n


    model = copy.deepcopy(policy0)

    model.head = torch.nn.Identity()


    gamma = 0.99

    p = 512

    n_episodes = 100

    alpha0 = 0.5

    eta = 2/3

    eps=0.0

    


    state = get_state(env.reset())


    phi0 = model(state.to(device)).detach().to('cpu').numpy().squeeze()


    CIs = TD0(n_episodes, gamma,alpha0,eta,200,phi0)


    values = CIs['value']

    CIs_Q = CIs['Q']

    CIs_SE = CIs['SE']


    # fig 5b

    start=10

    plt.rcParams['figure.figsize'] = (16,9)

    plt.rcParams['figure.facecolor'] = 'white'

    plt.plot(np.arange(n_episodes)[start:], values[start:], color='black', label='value estimate', linewidth=3)

    plt.plot(np.arange(n_episodes)[start:], CIs_Q[start:], color = 'blue', label='Q', linewidth=3)

    plt.plot(np.arange(n_episodes)[start:], CIs_SE[start:], color = 'red', label='SE', linewidth=3)

    handles, labels = plt.gca().get_legend_handles_labels()

    by_label = dict(zip(labels, handles))

    plt.legend(by_label.values(), by_label.keys(), fontsize='xx-large')

    plt.xlabel('Number of episodes', fontsize='xx-large')

    plt.xticks(fontsize=20)

    plt.yticks(fontsize=20)

    plt.savefig('atari_CI_example.png', bbox_inches='tight')

    plt.close()


    # fig 6a

    CIs_Q_widths = CIs_Q[:,1] - CIs_Q[:,0]

    CIs_SE_widths = CIs_SE[:,1] - CIs_SE[:,0]


    plt.rcParams['figure.figsize'] = (16,9)

    plt.rcParams['figure.facecolor'] = 'white'

    plt.plot(np.arange(n_episodes), CIs_Q_widths, color = 'blue', label='Q', linewidth=3)

    plt.plot(np.arange(n_episodes), CIs_SE_widths, color = 'red', label='SE', linewidth=3)

    plt.legend(fontsize='xx-large')

    plt.xlabel('Number of episodes', fontsize=20)

    plt.xticks(fontsize=20)

    plt.yticks(fontsize=20)

    plt.savefig('atari_CI_widths.png', bbox_inches='tight')

    plt.close()


    # fig 6b

    eps_seq = np.linspace(0,1,5)

    eps_dict = dict()


    for eps in eps_seq:

        eps_dict[eps] = TD0(n_episodes, gamma,alpha0,eta,200,phi0,eps)


    SE = np.array([x['SE'] for x in eps_dict.values()])

    Q = np.array([x['Q'] for x in eps_dict.values()])

    values = np.array([x['value'] for x in eps_dict.values()])


    plt.rcParams['figure.figsize'] = (16,9)

    plt.rcParams['figure.facecolor'] = 'white'

    plt.errorbar(eps_seq, values, yerr=Q.T, fmt='.k')

    plt.xlabel('epsilon', fontsize=20)

    plt.xticks(fontsize=20)

    plt.yticks(fontsize=20)

    plt.savefig('atari_CI_bars.png', bbox_inches='tight')

    plt.close()

Then the error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [2], in <cell line: 66>()
     85 eta = 2/3
     86 eps=0.0
---> 89 state = get_state(env.reset())
     91 phi0 = model(state.to(device)).detach().to('cpu').numpy().squeeze()
     93 CIs = TD0(n_episodes, gamma,alpha0,eta,200,phi0)

Input In [2], in get_state(obs)
     11 def get_state(obs):
     12     state = np.array(obs)
---> 13     state = state.transpose((2,0,1))
     14     state = torch.from_numpy(state)
     15     return state.unsqueeze(0).float()

ValueError: axes don't match array
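
One thing worth ruling out first is a gym API version mismatch: FrameStack.reset and EpisodicLifeEnv.step above already unpack the newer interface (reset() returning an (obs, info) tuple, step() returning five values), while the remaining wrappers assume the old four-value interface, so by the time get_state receives the observation it may no longer be a plain (H, W, C) frame. Below is a small diagnostic sketch, not a confirmed fix, to check which API the installed gym actually uses:

import gym

env = gym.make("PongNoFrameskip-v4")

print("gym version:", gym.__version__)

out = env.reset()

# gym >= 0.26 returns an (obs, info) tuple from reset(); older gym returns the observation only
print("reset ->", type(out))

step_out = env.step(env.action_space.sample())

# gym >= 0.26 returns (obs, reward, terminated, truncated, info); older gym returns four values
print("step  ->", len(step_out), "values")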
