-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
21 lines (17 loc) · 721 Bytes
/
utils.py
File metadata and controls
21 lines (17 loc) · 721 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import numpy as np
# Device that preprocessed tensors are placed on: CUDA when available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Environment ids that preprocess() knows how to handle. All of them currently
# need the same treatment (materialize as an array, convert to a float tensor);
# give an env its own branch in preprocess() if it ever needs more than that.
_SUPPORTED_ENVS = frozenset({'CartPole-v1', 'MountainCar-v0', 'Pong-v5', 'Breakout-v5'})


def preprocess(obs, env):
    """Performs necessary observation preprocessing.

    Converts a raw environment observation into a float32 tensor on the
    module-level ``device``.

    Args:
        obs: Raw observation from the environment; anything ``np.asarray``
            accepts (a NumPy array, a nested list of numbers, or an object
            exposing ``__array__``).
        env: Environment id string, e.g. ``'CartPole-v1'``.

    Returns:
        A new float32 ``torch.Tensor`` on ``device``. ``torch.tensor`` always
        copies its input, so the result never aliases ``obs``.

    Raises:
        ValueError: If ``env`` is not one of the supported environment ids.
    """
    if env not in _SUPPORTED_ENVS:
        raise ValueError('Please add necessary observation preprocessing instructions to preprocess() in utils.py.')
    # np.asarray materializes array-like observation objects (presumably e.g.
    # frame-stack wrappers for the Atari ids -- confirm) and is a no-op for
    # existing ndarrays, so a single conversion path serves every env.
    # The tensor is created in obs's native dtype and only then cast with
    # .float(), so compact dtypes (e.g. uint8 frames) are moved to `device`
    # before being widened to float32.
    return torch.tensor(np.asarray(obs), device=device).float()
# def grayscale(image):
# """Converts an image to gray scale"""
# return np.mean(image, 2)