forked from pkuzhf/storage-gym
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathenv_net.py
48 lines (37 loc) · 1.36 KB
/
env_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import config, utils
from keras import backend as K
from keras.models import Sequential, Input, Model
from keras.layers import Dense, Dropout, Flatten, Reshape, LeakyReLU, PReLU, merge, Activation
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization, GlobalAveragePooling2D
from keras.activations import relu
from keras.initializers import *
def get_env_net(filters=(32, 32, 32), use_bn=True):
    """Build the Keras "env" model: a small conv stack followed by a linear head.

    Each sample is an observation of shape ``(1, Width, Height, CellSize)``.
    The dimensions come from ``config.Map.Width``, ``config.Map.Height`` and
    ``utils.Cell.CellSize``. The leading singleton axis is reshaped away.

    Each entry in ``filters`` then adds one stage of
    ``Conv2D(3x3, padding='same') -> [BatchNormalization] -> ReLU``.

    The result is flattened and projected by a ``Dense`` layer with no
    activation to ``Width * Height + 1`` outputs.

    Args:
        filters: Number of filters for each 3x3 conv stage, in order. The
            default reproduces the original three 32-filter stages.
        use_bn: If true, insert ``BatchNormalization`` between each conv and
            its ReLU.

    Returns:
        The ``Model`` named ``'env'``; it is not compiled here.
    """
    n = config.Map.Height
    m = config.Map.Width
    d = utils.Cell.CellSize

    observation = Input(shape=(1, m, n, d), name='observation_input')
    # Drop the leading singleton axis so Conv2D sees (m, n, d) feature maps.
    x = Reshape((m, n, d))(observation)
    for num_filters in filters:
        x = Conv2D(filters=num_filters, kernel_size=(3, 3), padding='same')(x)
        if use_bn:
            x = BatchNormalization()(x)
        x = Activation(activation='relu')(x)
    x = Flatten()(x)
    # One output per map cell plus one extra unit. Presumably the extra unit
    # is a non-cell action; TODO confirm against the caller.
    actions = Dense(m * n + 1)(x)

    env_model = Model(inputs=observation, outputs=actions, name='env')
    print('env model:')
    # summary() prints the table itself and returns None; wrapping it in
    # print() only emitted a stray "None" line.
    env_model.summary()
    return env_model