-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdo_trainer.py
168 lines (152 loc) · 9.27 KB
/
do_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# -*- encoding: utf-8 -*-
import sys
import os
import argparse
from common import *
def parse_args(argv=None):
    """Parse the W2VV++ training command line and prepare the CUDA environment.

    Args:
        argv: Optional list of argument strings to parse (without the program
            name). ``None`` (the default) parses ``sys.argv[1:]``, which is
            what the script entry point relies on.

    Returns:
        argparse.Namespace holding every option defined below.

    Side effects:
        * Exports ``args.device`` verbatim as ``CUDA_VISIBLE_DEVICES`` and sets
          ``CUDA_LAUNCH_BLOCKING=1``.
        * Imports ``torch`` and prints the number of CUDA devices it can see.

    Raises:
        SystemExit: raised by argparse on invalid arguments or ``--help``.
    """
    # The string is a description, not a program name: passing it positionally
    # would make it ``prog`` and corrupt every usage line.
    parser = argparse.ArgumentParser(description='W2VVPP training script.')
    # Option help strings use argparse's %(default)s so the documented default
    # always matches the real one.
    parser.add_argument('--rootpath', type=str, default=ROOT_PATH,
                        help='path to datasets. (default: %(default)s)')
    # NOTE: positional arguments without nargs are required, so the defaults on
    # trainCollection / valCollection are never used; kept for reference only.
    parser.add_argument('trainCollection', type=str, default='msrvtt10k',
                        help='train collection')
    parser.add_argument('valCollection', type=str, default='tv2016train',
                        help='validation collection')
    # The string 'None' (not the None object) is the "not set" marker for the
    # options below; presumably downstream code compares against it — verify.
    parser.add_argument('--trainCollection2', type=str, default='None',
                        help='second train collection. (default: %(default)s)')
    parser.add_argument('--task2_caption', type=str, default='no_task2_caption',
                        help='the suffix of task2 caption (it looks like '
                             '"caption.nouns vocab_nouns"). (default: %(default)s)')
    parser.add_argument('--train_strategy', type=str, default='usual',
                        help='train strategy ("usual", "subset"). (default: %(default)s)')
    parser.add_argument('--overwrite', type=int, default=0, choices=[0, 1],
                        help='overwrite existing vocabulary file. (default: %(default)s)')
    parser.add_argument('--val_set', type=str, default='setA',
                        help='validation collection set (no, setA, setB). (default: %(default)s)')
    parser.add_argument('--metric', type=str, default='mir', choices=['r1', 'r5', 'medr', 'meanr', 'mir'],
                        help='performance metric on validation set. (default: %(default)s)')
    parser.add_argument('--num_epochs', default=10, type=int,
                        help='Number of training epochs. (default: %(default)s)')
    parser.add_argument('--batch_size', default=128, type=int,
                        help='Size of a training mini-batch. (default: %(default)s)')
    parser.add_argument('--workers', default=16, type=int,
                        help='Number of data loader workers. (default: %(default)s)')
    parser.add_argument('--model_prefix', default='runs_0', type=str,
                        help='Path to save the model and Tensorboard log. (default: %(default)s)')
    parser.add_argument('--config_name', type=str, default='w2vvpp_resnext101-resnet152_subspace',
                        help='model configuration file. (default: %(default)s)')
    # --model_name was removed; the model is selected via config.model_name.
    parser.add_argument('--parm_adjust_config', type=str, default='None',
                        help='the config parm you need to set. (default: %(default)s)')
    parser.add_argument('--device', default='2,3', type=str,
                        help='comma-separated GPU ids exported as CUDA_VISIBLE_DEVICES. '
                             '(default: %(default)s)')
    parser.add_argument('--random_seed', default=2, type=int,
                        help='random_seed of the trainer. (default: %(default)s)')
    parser.add_argument('--local_rank', default=0, type=int,
                        help='distributed rank if using multi-gpu. (default: %(default)s)')
    parser.add_argument('--pretrained_file_path', default='None', type=str,
                        help='path of a previously trained model to start from. '
                             '(default: %(default)s)')
    parser.add_argument('--save_mean_last', default=0, type=int, choices=[0, 1],
                        help='Whether to save the average of the last 10 epoch models. '
                             '(default: %(default)s)')
    parser.add_argument('--task3_caption', type=str, default='no_task3_caption',
                        help='the suffix of task3 caption (it looks like "caption.false"). '
                             '(default: %(default)s)')
    args = parser.parse_args(argv)

    # Device selection must be in the environment before CUDA is initialised,
    # which is why torch is imported only afterwards.
    # NOTE(review): if `common` already initialised CUDA, this is too late — confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    # Synchronous kernel launches: easier CUDA error attribution, slower training.
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    import torch
    print(torch.cuda.device_count())
    return args
if __name__ == '__main__':
    if len(sys.argv) == 1:
        # No command-line arguments. The presets below look meant to be
        # uncommented for IDE/debug runs. As written, this branch only prints
        # a blank line. parse_args() then exits with a usage error, because
        # trainCollection / valCollection are required positionals.
        # NOTE(review): presets that pass --model_name will fail with
        # "unrecognized arguments". parse_args() no longer defines that option
        # (it was abandoned in favour of config.model_name).
        print()
        # from model.model import get_model, get_we
        # gcc
        # sys.argv = "trainer.py --device 4 gcc11train gcc11val " \
        # "--rootpath /home/liupengju/hf_code/VisualSearch --batch_size 256 " \
        # "--workers 16 " \
        # "--train_strategy usual " \
        # "--model_name w2vpp_mutivis_attention " \
        # "--config w2vvpp_mutiVisual_subspace_AvsDataset_AdjustCLIP " \
        # "--parm_adjust_config 0_9_1_9 " \
        # "--val_set no " \
        # "--model_prefix bow_w2v_runs_test1 --overwrite 1".split(' ')
        # msrvtt with task2
        # sys.argv = "trainer.py --device 3 msrvtt10ktrain msrvtt10kval " \
        # "--rootpath /home/liupengju/hf_code/VisualSearch --batch_size 256 " \
        # "--workers 10 --task2_caption vg_confidence_thresholdSumRank_freqBt5 " \
        # "--train_strategy usual " \
        # "--model_name w2vvpp_wv2object " \
        # "--config_name w2vvpp_resnext101resnet152_subspace_addPritrainedObject_bow_adjust_alpha " \
        # "--parm_adjust_config 0 " \
        # "--val_set no " \
        # "--pretrained_file_path /data/liupengju/hf/gcc11train/w2vvpp_train/gcc11train/w2vvpp_resnext101resnet152_subspace_addPritrainedObject_bow_adjust_alpha/runs_bow_w2vvpp_wv2object_a_0.5/model_best.pth.tar " \
        # "--model_prefix bow_w2v_runs_test1 --overwrite 1".split(' ')
        # msrvtt w2vvpp_attention
        # sys.argv = "trainer.py --device 0 msrvtt10ktrain msrvtt10kval " \
        # "--rootpath /home/liupengju/hf_code/VisualSearch --batch_size 64 " \
        # "--workers 10 --task2_caption vg_confidence_thresholdSumRank_freqBt5 " \
        # "--train_strategy usual " \
        # "--model_name w2vpp_mutivis_attention " \
        # "--config_name w2vvpp_mutiVisual_subspace_AdjustGlobalWeight " \
        # "--parm_adjust_config 5_9_2_0.95_9_1 " \
        # "--val_set no " \
        # "--pretrained_file_path None " \
        # "--model_prefix bow_w2v_runs_test1 --overwrite 1".split(' ')
        # msrvtt w2vpp_MutiVisFrameFeat_attention
        # sys.argv = "trainer.py --device 0 msrvtt10ktrain msrvtt10kval " \
        # "--rootpath /home/liupengju/hf_code/VisualSearch --batch_size 256 " \
        # "--workers 10 " \
        # "--train_strategy usual " \
        # "--model_name w2vpp_MutiVisFrameFeat_attention " \
        # "--config_name w2vvpp_mutiVisual_subspace_AdjustVisframeEncoder " \
        # "--parm_adjust_config 0_4_1 " \
        # "--val_set no " \
        # "--pretrained_file_path None " \
        # "--model_prefix bow_w2v_runs_test1 --overwrite 1".split(' ')
        # msrvtt sea
        # sys.argv = "trainer.py --device 4 msrvtt10ktrain msrvtt10kval " \
        # "--rootpath /home/liupengju/hf_code/VisualSearch --batch_size 256 " \
        # "--workers 10 " \
        # "--train_strategy usual " \
        # "--config_name AAAI.sea_adjustVisTxt " \
        # "--parm_adjust_config 1_8_6_8 " \
        # "--val_set no " \
        # "--pretrained_file_path None " \
        # "--model_prefix bow_w2v_runs_test1 --overwrite 1".split(' ')
        # tgif w2vvpp_attention
        # sys.argv = "trainer.py --device 1 tgif-msrvtt10k tv2016train " \
        # "--rootpath /home/liupengju/hf_code/VisualSearch --batch_size 128 " \
        # "--workers 10 --task2_caption nouns " \
        # "--train_strategy usual " \
        # "--config_name AAAI.sea_adjustVisTxt " \
        # "--parm_adjust_config 1_8_0_8 " \
        # "--val_set setA " \
        # "--model_prefix test1 --overwrite 1".split(' ')
        # tgif-vatex sea
        # sys.argv = "trainer.py --device 0 tgif-msrvtt10k-vatex tv2016train " \
        # "--rootpath /home/liupengju/hf_code/VisualSearch --batch_size 64 " \
        # "--workers 10 " \
        # "--train_strategy usual " \
        # "--config_name experiments.sea_only_visual_multi_head_avs_adjustVisTxt " \
        # "--parm_adjust_config 2_11_0_1 " \
        # "--val_set setA " \
        # "--model_prefix test1 --overwrite 1".split(' ')
        # tgif-vatex sea_multi_head
        # sys.argv = "trainer.py --device 0 tgif-msrvtt10k-vatex tv2016train " \
        # "--rootpath /home/liupengju/hf_code/VisualSearch --batch_size 64 " \
        # "--workers 10 " \
        # "--train_strategy usual " \
        # "--config_name experiments.sea_only_visual_multi_head_add_concat_avs_adjustVisTxt " \
        # "--parm_adjust_config 3_11_1_1 " \
        # "--val_set setA " \
        # "--model_prefix test1 --overwrite 1".split(' ')
        # coco sea_multi_head
        # NOTE(review): despite the "coco" label, this preset trains on
        # tgif-msrvtt10k-vatex — confirm the intended collection.
        # sys.argv = "trainer.py --device 0 tgif-msrvtt10k-vatex tv2016train " \
        # "--rootpath /home/liupengju/hf_code/VisualSearch --batch_size 64 " \
        # "--workers 10 " \
        # "--train_strategy usual " \
        # "--config_name experiments.sea_plus_avs_multi_head_adjustVisTxt " \
        # "--parm_adjust_config 11_12_1_12 " \
        # "--val_set setA " \
        # "--model_prefix test1 --overwrite 1".split(' ')
    opt = parse_args()
    # trainer is imported only after parse_args() has exported
    # CUDA_VISIBLE_DEVICES, so the device restriction is already in the
    # environment when the training code loads. Only `main` is needed;
    # importing trainer's own parse_args would rebind this module's parse_args.
    from trainer import main
    main(opt)