| | import random |
| | import numpy as np |
| | from scipy.misc import imresize |
| |
|
| | import random |
| | import numpy as np |
| | from scipy.ndimage import rotate |
| | from scipy.misc import imresize |
| | from matplotlib import pyplot as plt |
| |
|
| | class CatchEnv2: |
| | def __init__(self): |
| | self.size = 21 |
| | self.image = np.zeros((self.size, self.size)) |
| | self.state = [] |
| | self.fps = 4 |
| | self.output_shape = (84, 84) |
| |
|
| | def reset_random(self): |
| | self.image.fill(0) |
| | self.pos = np.random.randint(2, self.size-2) |
| | self.vx = np.random.randint(5) - 2 |
| | self.vy = 1 |
| | self.ballx, self.bally = np.random.randint(self.size), 4 |
| | |
| | self.image[self.bally, self.ballx] = 1 |
| | self.image[-5, self.pos - 2:self.pos + 3] = np.ones(5) |
| | |
| | for i in range(0, self.size): |
| | for j in range(0, self.size): |
| | self.image[i][j] = random.randint(2,5) |
| | |
| | return self.step(2)[0] |
| |
|
| |
|
| | def step(self, action): |
| | def left(): |
| | if self.pos > 3: |
| | self.pos -= 2 |
| | def right(): |
| | if self.pos < 17: |
| | self.pos += 2 |
| | def noop(): |
| | pass |
| | {0: left, 1: right, 2: noop}[action]() |
| |
|
| | |
| | self.image[self.bally, self.ballx] = 0 |
| | self.ballx += self.vx |
| | self.bally += self.vy |
| | if self.ballx > self.size - 1: |
| | self.ballx -= 2 * (self.ballx - (self.size-1)) |
| | self.vx *= -1 |
| | elif self.ballx < 0: |
| | self.ballx += 2 * (0 - self.ballx) |
| | self.vx *= -1 |
| | |
| | self.image[self.bally, self.ballx] = 1 |
| | self.image[-5].fill(random.randint(2,5)) |
| | self.image[-5, self.pos-2:self.pos+3] = np.ones(5) |
| | |
| | terminal = self.bally == self.size - 1 - 4 |
| | reward = int(self.pos - 2 <= self.ballx <= self.pos + 2) if terminal else 0 |
| |
|
| | [self.state.append(imresize(self.image, (84, 84))) for _ in range(self.fps - len(self.state) + 1)] |
| | self.state = self.state[-self.fps:] |
| |
|
| | return np.transpose(self.state, [1, 2, 0]), reward, terminal |
| |
|
| | def get_num_actions(self): |
| | return 3 |
| |
|
| | def reset(self): |
| | return self.reset_random() |
| |
|
| | def state_shape(self): |
| | return (self.fps,) + self.output_shape |
| | |
| | def show_state(self, i): |
| | plt.imshow(self.image) |
| | plt.imsave('image_'+str(i)+'.jpg', self.image) |
| |
|
| | def test(): |
| | env = CatchEnv2() |
| | i = 0 |
| | for ep in range(1): |
| | env.reset() |
| | env.show_state(i) |
| | |
| | state, reward, terminal = env.step(1) |
| | while not terminal: |
| | env.show_state(i) |
| | state, reward, terminal = env.step(np.random.randint(0,2)) |
| | i += 1 |
| | |
| |
|
| | if __name__ == "main": |
| | test() |
| |
|