Skip to content

Commit b3be914

Browse files
committed
fix: change optical flow calculation & fix video examples
1 parent 1822b22 commit b3be914

File tree

4 files changed

+632
-278
lines changed

4 files changed

+632
-278
lines changed

DPF/filters/videos/farneback_filter.py

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import io
2-
from typing import Any
2+
from typing import Any, Optional
33

44
import cv2
55
import imageio.v3 as iio
@@ -17,6 +17,21 @@ def transform_frame(frame: MatLike, target_size: tuple[int, int]) -> MatLike:
1717
return frame
1818

1919

20+
def transform_keep_ar(frame: MatLike, min_side_size: int) -> MatLike:
    """Resize *frame* so its shorter side equals ``min_side_size`` while
    preserving aspect ratio, then convert the result to grayscale.

    Returns a single-channel (grayscale) frame ready for Farneback
    optical-flow computation.
    """
    height, width = frame.shape[:2]
    ratio = width / height
    landscape = height <= width
    # Pin the shorter side to min_side_size; derive the other from the ratio.
    target_h = min_side_size if landscape else int(min_side_size / ratio)
    target_w = int(ratio * min_side_size) if landscape else min_side_size

    gray = cv2.cvtColor(
        cv2.resize(frame, dsize=(target_w, target_h), interpolation=cv2.INTER_LINEAR),
        cv2.COLOR_BGR2GRAY,
    )
    return gray
33+
34+
2035
class GunnarFarnebackFilter(VideoFilter):
2136
"""
2237
Gunnar-Farneback filter inference class to get mean optical flow each video.
@@ -46,7 +61,9 @@ class GunnarFarnebackFilter(VideoFilter):
4661

4762
def __init__(
4863
self,
49-
pass_frames: int = 10,
64+
pass_frames: int = 12,
65+
num_passes: Optional[int] = None,
66+
min_frame_size: int = 512,
5067
pyramid_scale: float = 0.5,
5168
levels: int = 3,
5269
win_size: int = 15,
@@ -55,14 +72,16 @@ def __init__(
5572
poly_sigma: float = 1.2,
5673
workers: int = 16,
5774
flags: int = 0,
58-
batch_size: int = 1,
5975
pbar: bool = True,
6076
_pbar_position: int = 0
6177
):
6278
super().__init__(pbar, _pbar_position)
6379

6480
self.num_workers = workers
65-
self.batch_size = batch_size
81+
82+
self.num_passes = num_passes
83+
self.min_frame_size = min_frame_size
84+
self.pass_frames = pass_frames
6685

6786
self.pyramid_scale = pyramid_scale
6887
self.levels = levels
@@ -72,17 +91,15 @@ def __init__(
7291
self.poly_sigma = poly_sigma
7392
self.flags = flags
7493

75-
self.pass_frames = pass_frames
76-
7794
@property
def result_columns(self) -> list[str]:
    """Names of the columns this filter contributes to the result table."""
    # NOTE(review): the single column holds the per-video optical-flow value
    # appended in process_batch (rounded to 3 decimals) — confirm any
    # downstream consumers were updated from the old
    # "mean_optical_flow_farneback" name.
    return ["optical_flow_farneback"]
8097

8198
@property
def dataloader_kwargs(self) -> dict[str, Any]:
    """Keyword arguments forwarded to the DataLoader."""
    # batch_size is pinned to 1 — NOTE(review): presumably one whole video
    # per sample, so variable-length frame stacks are never collated; confirm.
    return dict(
        num_workers=self.num_workers,
        batch_size=1,
        drop_last=False,
    )
88105

@@ -95,27 +112,18 @@ def preprocess_data(
95112
video_file = modality2data['video']
96113

97114
frames = iio.imread(io.BytesIO(video_file), plugin="pyav")
115+
max_frame_to_process = self.num_passes*self.pass_frames if self.num_passes else len(frames)
116+
frames_transformed = []
98117

99-
if frames.shape[1] > frames.shape[2]:
100-
frames_resized = [
101-
transform_frame(frame=frames[i], target_size=(450, 800))
102-
for i in range(self.pass_frames, len(frames), self.pass_frames)
103-
]
104-
elif frames.shape[2] > frames.shape[1]:
105-
frames_resized = [
106-
transform_frame(frame=frames[i], target_size=(800, 450))
107-
for i in range(self.pass_frames, len(frames), self.pass_frames)
108-
]
109-
else:
110-
frames_resized = [
111-
transform_frame(frame=frames[i], target_size=(450, 450))
112-
for i in range(self.pass_frames, len(frames), self.pass_frames)
113-
]
118+
frames_transformed = [
119+
transform_keep_ar(frames[i], self.min_frame_size)
120+
for i in range(self.pass_frames, min(max_frame_to_process+1, len(frames)), self.pass_frames)
121+
]
114122

115123
mean_magnitudes: list[float] = []
116-
for i in range(self.pass_frames, len(frames_resized), self.pass_frames):
117-
current_frame = frames_resized[i - self.pass_frames]
118-
next_frame = frames_resized[i]
124+
for i in range(len(frames_transformed)-1):
125+
current_frame = frames_transformed[i]
126+
next_frame = frames_transformed[i+1]
119127
flow = cv2.calcOpticalFlowFarneback(
120128
current_frame,
121129
next_frame,
@@ -139,5 +147,5 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
139147
for data in batch:
140148
key, mean_optical_flow = data
141149
df_batch_labels[self.key_column].append(key)
142-
df_batch_labels['mean_optical_flow_farneback'].append(round(mean_optical_flow, 3))
150+
df_batch_labels[self.result_columns[0]].append(round(mean_optical_flow, 3))
143151
return df_batch_labels

DPF/filters/videos/raft_filter.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import io
22
import os
3-
from typing import Any
3+
from typing import Any, Optional
44
from urllib.request import urlopen
55
from zipfile import ZipFile
66

@@ -28,6 +28,24 @@ def transform_frame(frame: MatLike, target_size: tuple[int, int]) -> Tensor:
2828
return frame_tensor
2929

3030

31+
def transform_keep_ar(frame: MatLike, min_side_size: int) -> Tensor:
    """Resize *frame* so its shorter side equals ``min_side_size`` (aspect
    ratio preserved), convert it to a 1xCxHxW float tensor, and pad the
    spatial dims via ``InputPadder`` so both are divisible by 8.
    """
    height, width = frame.shape[:2]
    ratio = width / height
    if height <= width:
        target_h, target_w = min_side_size, int(ratio * min_side_size)
    else:
        target_h, target_w = int(min_side_size / ratio), min_side_size

    resized = cv2.resize(frame, dsize=(target_w, target_h), interpolation=cv2.INTER_LINEAR)
    # HWC uint8 -> 1xCxHxW float, as expected by the RAFT model.
    tensor = torch.from_numpy(resized).permute(2, 0, 1).float()[None]

    padder = InputPadder(tensor.shape)  # type: ignore
    return padder.pad(tensor)[0]
47+
48+
3149
class InputPadder:
3250
""" Pads images such that dimensions are divisible by 8 """
3351

@@ -62,20 +80,24 @@ class RAFTOpticalFlowFilter(VideoFilter):
6280
def __init__(
6381
self,
6482
pass_frames: int = 10,
83+
num_passes: Optional[int] = None,
84+
min_frame_size: int = 512,
6585
use_small_model: bool = False,
86+
raft_iters: int = 20,
6687
device: str = "cuda:0",
6788
workers: int = 16,
68-
batch_size: int = 1,
6989
pbar: bool = True,
7090
_pbar_position: int = 0
7191
):
7292
super().__init__(pbar, _pbar_position)
7393
self.num_workers = workers
74-
self.batch_size = batch_size
7594
self.device = device
7695

7796
assert pass_frames >= 1, "Number of pass_frames should be greater or equal to 1."
7897
self.pass_frames = pass_frames
98+
self.num_passes = num_passes
99+
self.min_frame_size = min_frame_size
100+
self.raft_iters = raft_iters
79101

80102
resp = urlopen(WEIGHTS_URL)
81103
zipped_files = ZipFile(io.BytesIO(resp.read()))
@@ -98,13 +120,13 @@ def __init__(
98120

99121
@property
def result_columns(self) -> list[str]:
    """Result column name, parameterized by the RAFT model variant in use."""
    # NOTE(review): renamed from "mean_optical_flow_{model_name}" — confirm
    # downstream consumers were updated accordingly.
    return [f"optical_flow_{self.model_name}"]
102124

103125
@property
def dataloader_kwargs(self) -> dict[str, Any]:
    """Keyword arguments forwarded to the DataLoader."""
    # batch_size is pinned to 1 — NOTE(review): presumably one whole video
    # per sample, so variable-length frame stacks are never collated; confirm.
    return dict(
        num_workers=self.num_workers,
        batch_size=1,
        drop_last=False,
    )
110132

@@ -117,23 +139,13 @@ def preprocess_data(
117139
video_file = modality2data['video']
118140

119141
frames = iio.imread(io.BytesIO(video_file), plugin="pyav")
120-
121-
if frames.shape[1] > frames.shape[2]:
122-
frames_resized = [
123-
transform_frame(frame=frames[i], target_size=(450, 800))
124-
for i in range(self.pass_frames, len(frames), self.pass_frames)
125-
]
126-
elif frames.shape[2] > frames.shape[1]:
127-
frames_resized = [
128-
transform_frame(frame=frames[i], target_size=(800, 450))
129-
for i in range(self.pass_frames, len(frames), self.pass_frames)
130-
]
131-
else:
132-
frames_resized = [
133-
transform_frame(frame=frames[i], target_size=(450, 450))
134-
for i in range(self.pass_frames, len(frames), self.pass_frames)
135-
]
136-
return key, frames_resized
142+
max_frame_to_process = self.num_passes*self.pass_frames if self.num_passes else len(frames)
143+
frames_transformed = []
144+
frames_transformed = [
145+
transform_keep_ar(frames[i], self.min_frame_size)
146+
for i in range(self.pass_frames, min(max_frame_to_process+1, len(frames)), self.pass_frames)
147+
]
148+
return key, frames_transformed
137149

138150
def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
139151
df_batch_labels = self._get_dict_from_schema()
@@ -142,28 +154,21 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
142154
for data in batch:
143155
key, frames = data
144156
with torch.no_grad():
145-
for i in range(self.pass_frames, len(frames), self.pass_frames):
146-
current_frame = frames[i - self.pass_frames]
147-
next_frame = frames[i]
157+
for i in range(len(frames)-1):
158+
current_frame = frames[i]
159+
next_frame = frames[i+1]
148160

149-
if (i - self.pass_frames) == 0:
161+
if i == 0:
150162
current_frame_cuda = current_frame.to(self.device)
151-
next_frame_cuda = next_frame.to(self.device)
152-
153-
_, flow = self.model(
154-
current_frame_cuda,
155-
next_frame_cuda,
156-
iters=20, test_mode=True
157-
)
158163
else:
159164
current_frame_cuda = next_frame_cuda
160-
next_frame_cuda = next_frame.to(self.device)
161165

162-
_, flow = self.model(
163-
current_frame_cuda,
164-
next_frame_cuda,
165-
iters=20, test_mode=True
166-
)
166+
next_frame_cuda = next_frame.to(self.device)
167+
_, flow = self.model(
168+
current_frame_cuda,
169+
next_frame_cuda,
170+
iters=self.raft_iters, test_mode=True
171+
)
167172

168173
flow = flow.detach().cpu().numpy()
169174
magnitude, angle = cv2.cartToPolar(flow[0][..., 0], flow[0][..., 1])

0 commit comments

Comments (0)