11import  io 
22import  os 
3- from  typing  import  Any 
3+ from  typing  import  Any ,  Optional 
44from  urllib .request  import  urlopen 
55from  zipfile  import  ZipFile 
66
@@ -28,6 +28,24 @@ def transform_frame(frame: MatLike, target_size: tuple[int, int]) -> Tensor:
2828 return  frame_tensor 
2929
3030
def transform_keep_ar(frame: MatLike, min_side_size: int) -> Tensor:
    """Resize a frame so its shorter side equals ``min_side_size`` and pad it.

    The width/height ratio of the input is kept (up to integer truncation of
    the longer side). The resized frame is converted to a float tensor with
    its last axis moved to the front and a leading axis of size 1 added,
    then passed through ``InputPadder``.

    Args:
        frame: Image array whose first two axes are read as (height, width).
            A third axis is required, since the tensor is permuted with
            ``(2, 0, 1)``.
        min_side_size: Target length, in pixels, of the shorter image side.

    Returns:
        The first element of ``InputPadder.pad(...)`` for the resized frame.
    """
    height, width = frame.shape[:2]
    ratio = width / height

    if width < height:
        # Portrait input: the width is the short side and gets pinned.
        target_w = min_side_size
        target_h = int(target_w / ratio)
    else:
        # Landscape or square input: the height is pinned instead.
        target_h = min_side_size
        target_w = int(ratio * target_h)

    # cv2.resize expects dsize as (width, height).
    resized = cv2.resize(
        frame,
        dsize=(target_w, target_h),
        interpolation=cv2.INTER_LINEAR,
    )

    # Move the last axis first, cast to float, and add a leading axis of size 1.
    batched = torch.from_numpy(resized).permute(2, 0, 1).float().unsqueeze(0)

    # InputPadder is described in this file as padding to dimensions
    # divisible by 8.
    # NOTE(review): pad() presumably returns one padded tensor per input;
    # confirm against InputPadder.pad.
    padder = InputPadder(batched.shape)  # type: ignore
    return padder.pad(batched)[0]
47+ 
48+ 
3149class  InputPadder :
3250 """ Pads images such that dimensions are divisible by 8 """ 
3351
@@ -62,20 +80,24 @@ class RAFTOpticalFlowFilter(VideoFilter):
6280 def  __init__ (
6381 self ,
6482 pass_frames : int  =  10 ,
83+  num_passes : Optional [int ] =  None ,
84+  min_frame_size : int  =  512 ,
6585 use_small_model : bool  =  False ,
86+  raft_iters : int  =  20 ,
6687 device : str  =  "cuda:0" ,
6788 workers : int  =  16 ,
68-  batch_size : int  =  1 ,
6989 pbar : bool  =  True ,
7090 _pbar_position : int  =  0 
7191 ):
7292 super ().__init__ (pbar , _pbar_position )
7393 self .num_workers  =  workers 
74-  self .batch_size  =  batch_size 
7594 self .device  =  device 
7695
7796 assert  pass_frames  >=  1 , "Number of pass_frames should be greater or equal to 1." 
7897 self .pass_frames  =  pass_frames 
98+  self .num_passes  =  num_passes 
99+  self .min_frame_size  =  min_frame_size 
100+  self .raft_iters  =  raft_iters 
79101
80102 resp  =  urlopen (WEIGHTS_URL )
81103 zipped_files  =  ZipFile (io .BytesIO (resp .read ()))
@@ -98,13 +120,13 @@ def __init__(
98120
99121 @property  
100122 def  result_columns (self ) ->  list [str ]:
101-  return  [f"mean_optical_flow_ { self .model_name }  " ]
123+  return  [f"optical_flow_ { self .model_name }  " ]
102124
103125 @property  
104126 def  dataloader_kwargs (self ) ->  dict [str , Any ]:
105127 return  {
106128 "num_workers" : self .num_workers ,
107-  "batch_size" : self . batch_size ,
129+  "batch_size" : 1 ,
108130 "drop_last" : False ,
109131 }
110132
@@ -117,23 +139,13 @@ def preprocess_data(
117139 video_file  =  modality2data ['video' ]
118140
119141 frames  =  iio .imread (io .BytesIO (video_file ), plugin = "pyav" )
120- 
121-  if  frames .shape [1 ] >  frames .shape [2 ]:
122-  frames_resized  =  [
123-  transform_frame (frame = frames [i ], target_size = (450 , 800 ))
124-  for  i  in  range (self .pass_frames , len (frames ), self .pass_frames )
125-  ]
126-  elif  frames .shape [2 ] >  frames .shape [1 ]:
127-  frames_resized  =  [
128-  transform_frame (frame = frames [i ], target_size = (800 , 450 ))
129-  for  i  in  range (self .pass_frames , len (frames ), self .pass_frames )
130-  ]
131-  else :
132-  frames_resized  =  [
133-  transform_frame (frame = frames [i ], target_size = (450 , 450 ))
134-  for  i  in  range (self .pass_frames , len (frames ), self .pass_frames )
135-  ]
136-  return  key , frames_resized 
142+  max_frame_to_process  =  self .num_passes * self .pass_frames  if  self .num_passes  else  len (frames )
143+  frames_transformed  =  []
144+  frames_transformed  =  [
145+  transform_keep_ar (frames [i ], self .min_frame_size )
146+  for  i  in  range (self .pass_frames , min (max_frame_to_process + 1 , len (frames )), self .pass_frames )
147+  ]
148+  return  key , frames_transformed 
137149
138150 def  process_batch (self , batch : list [Any ]) ->  dict [str , list [Any ]]:
139151 df_batch_labels  =  self ._get_dict_from_schema ()
@@ -142,28 +154,21 @@ def process_batch(self, batch: list[Any]) -> dict[str, list[Any]]:
142154 for  data  in  batch :
143155 key , frames  =  data 
144156 with  torch .no_grad ():
145-  for  i  in  range (self . pass_frames ,  len (frames ),  self . pass_frames ):
146-  current_frame  =  frames [i   -   self . pass_frames ]
147-  next_frame  =  frames [i ]
157+  for  i  in  range (len (frames )- 1 ):
158+  current_frame  =  frames [i ]
159+  next_frame  =  frames [i + 1 ]
148160
149-  if  ( i   -   self . pass_frames )  ==  0 :
161+  if  i  ==  0 :
150162 current_frame_cuda  =  current_frame .to (self .device )
151-  next_frame_cuda  =  next_frame .to (self .device )
152- 
153-  _ , flow  =  self .model (
154-  current_frame_cuda ,
155-  next_frame_cuda ,
156-  iters = 20 , test_mode = True 
157-  )
158163 else :
159164 current_frame_cuda  =  next_frame_cuda 
160-  next_frame_cuda  =  next_frame .to (self .device )
161165
162-  _ , flow  =  self .model (
163-  current_frame_cuda ,
164-  next_frame_cuda ,
165-  iters = 20 , test_mode = True 
166-  )
166+  next_frame_cuda  =  next_frame .to (self .device )
167+  _ , flow  =  self .model (
168+  current_frame_cuda ,
169+  next_frame_cuda ,
170+  iters = self .raft_iters , test_mode = True 
171+  )
167172
168173 flow  =  flow .detach ().cpu ().numpy ()
169174 magnitude , angle  =  cv2 .cartToPolar (flow [0 ][..., 0 ], flow [0 ][..., 1 ])
0 commit comments