# extract_video_features_clip.py
import torch
import cv2
import numpy as np
from torchvision import transforms
from model.model_clip import build_model
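
# build_model (from model/model_clip.py in this repo) rebuilds the CLIP-based model
# from a checkpoint state dict, as done at the bottom of this script.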
def init_transform_dict(input_res=224,
                        center_crop=256,
                        randcrop_scale=(0.5, 1.0),
                        color_jitter=(0, 0, 0),
                        norm_mean=(0.48145466, 0.4578275, 0.40821073),
                        norm_std=(0.26862954, 0.26130258, 0.27577711)):
    normalize = transforms.Normalize(mean=norm_mean, std=norm_std)
    tsfm_dict = {
        'test': transforms.Compose([
            transforms.Resize(center_crop),
            transforms.CenterCrop(center_crop),
            transforms.Resize(input_res),
            normalize,
        ])
    }
    return tsfm_dict
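
# The 'test' pipeline above operates on a (T x C x H x W) float tensor in [0, 1]:
# the shorter side is resized to `center_crop`, the frame is center-cropped to a square,
# resized to `input_res`, and normalized with CLIP's channel statistics. The
# randcrop_scale and color_jitter arguments are unused in this script.
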
def sample_frames(num_frames, vlen):
    acc_samples = min(num_frames, vlen)
    intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
    ranges = []
    for idx, interv in enumerate(intervals[:-1]):
        ranges.append((interv, intervals[idx + 1] - 1))
    frame_idxs = [(x[0] + x[1]) // 2 for x in ranges]
    return frame_idxs
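
# For example, sampling 4 frames from a 100-frame video takes the midpoint of each
# quarter: sample_frames(4, 100) -> [12, 37, 62, 87].
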
def read_frames_cv2(video_path, num_frames):
    cap = cv2.VideoCapture(video_path)
    assert cap.isOpened()
    vlen = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # get indexes of sampled frames
    frame_idxs = sample_frames(num_frames, vlen)
    frames = []
    success_idxs = []
    for index in frame_idxs:
        cap.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = torch.from_numpy(frame)
            # (H x W x C) to (C x H x W)
            frame = frame.permute(2, 0, 1)
            frames.append(frame)
            success_idxs.append(index)
    # stack into (T x C x H x W) and scale pixel values to [0, 1]
    frames = torch.stack(frames).float() / 255
    cap.release()
    return frames, success_idxs

# Build the test-time preprocessing pipeline and read the sampled frames from a video.
video_transforms = init_transform_dict()['test']
video_path = ''  # path to the input video file
num_frames = 4
video, idxs = read_frames_cv2(video_path, num_frames)
video = video_transforms(video).unsqueeze(0).cuda()
print(video.shape)
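# If all sampled frames were read successfully, the printed shape should be
# torch.Size([1, 4, 3, 224, 224]): (batch, frames, channels, height, width).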

# Load the pretrained checkpoint and build the model from its state dict.
model_path = './MCQ_CLIP.pth'
model_clip = torch.load(model_path, map_location="cpu")
state_dict = model_clip['state_dict']
model = build_model(state_dict)
model = model.cuda()
model.eval()

# Encode the video and L2-normalize the features; no gradients are needed for inference.
with torch.no_grad():
    video_features = model.encode_image(video)
video_features = video_features / video_features.norm(dim=-1, keepdim=True)
print(video_features.shape)
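
# Because the features are L2-normalized, cosine similarity between two clips reduces to
# a dot product. A minimal sketch, assuming a second clip has been encoded the same way
# into a hypothetical tensor `other_features` of the same shape:
#
#     similarity = (video_features * other_features).sum(dim=-1)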