Skip to content

Commit d14f203

Browse files
committed
wave
1 parent 8cd75de commit d14f203

File tree

13 files changed

+2074
-2
lines changed

13 files changed

+2074
-2
lines changed

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
| Component | Status |
99
|----------------------------------|----------------|
1010
| 🧠 Model Definition | ✅ Available |
11-
| 🏋️‍♀️ Training Code | ✅ Available |
11+
| 🏋️‍♀️ Training Code | ⏳ Coming Soon |
1212
| 🧪 Inference Code | ✅ Available |
1313
| 🎯 Model Weights | ⏳ Coming Soon |
1414
| 📊 Dataset | ⏳ Coming Soon |
@@ -20,8 +20,9 @@
2020
- **Status**: Under Review at [AAAI 2026]
2121

2222
---
23+
2324
## Model Weights and Datasets
24-
**Model Weights** and **Datasets** will be made public after the paper is accepted. (*Coming Soon!!!*)
25+
**Training Code**, **Model Weights** and **Datasets** will be made public after the paper is accepted. (*Coming Soon!!!*)
2526

2627
## Abstract
2728
Accurately localizing and segmenting occluded objects from faint light patterns beyond the field of view is highly challenging due to multiple scattering and medium-induced perturbations. Most existing methods, based on real-valued modeling or local convolutional operations, are inadequate for capturing the underlying physics of coherent light propagation. Moreover, under low signal-to-noise conditions, these methods often converge to non-physical solutions, severely compromising the stability and reliability of the observation. To address these challenges, we propose a novel physics-driven Wavefront Propagating Compensation Network (WavePCNet) to simulate wavefront propagation and enhance the perception of occluded objects. This WavePCNet integrates the Tri-Phase Wavefront Complex-Propagation Reprojection (TriWCP) to incorporate complex amplitude transfer operators to precisely constrain coherent propagation behavior, along with a momentum memory mechanism to effectively suppress the accumulation of perturbations. Additionally, a High-frequency Cross-layer Compensation Enhancement is introduced to construct frequency-selective pathways with multi-scale receptive fields and dynamically model structural consistency across layers, further boosting the model’s robustness and interpretability under complex environmental conditions. Extensive experiments conducted on four physically collected datasets demonstrate that WavePCNet consistently outperforms state-of-the-art methods across both accuracy and robustness. All data and code will be publicly released to support and encourage continued research in the obscured object detection domain.
@@ -37,5 +38,7 @@ Accurately localizing and segmenting occluded objects from faint light patterns
3738
## Result
3839
![](./img/result.png)
3940

41+
---
4042

43+
## Test
4144

base/config.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import sys
2+
import os
3+
import argparse
4+
import importlib
5+
from .util import *
6+
7+
def base_config(net_name):
    """Build the run configuration for ``net_name`` from the command line.

    The module ``methods.<net_name>`` must expose ``custom_config``, a dict
    with two entries: ``'base'`` (new default values for the arguments
    declared here) and ``'customized'`` (a mapping from ``--flag`` strings to
    ``add_argument`` keyword arguments for model-specific options).

    Args:
        net_name: Name of the module under ``methods`` to configure.

    Returns:
        dict: The parsed arguments. ``vals`` is expanded to a list of dataset
        names, and ``save_path`` / ``weight_path`` are extended with
        ``<model_name>/<backbone>/<sub>`` and passed to ``check_path``
        (presumably creating the directories - confirm in base/util.py).
    """
    parser = argparse.ArgumentParser()
    model_default_config = importlib.import_module('methods.{}'.format(net_name)).custom_config

    # NOTE: a positional without ``nargs='?'`` is always required, so this
    # default is never used; it is kept to leave the CLI contract unchanged.
    parser.add_argument('model_name', default=net_name, help='Model name')
    parser.add_argument('--backbone', default='resnet50', help='Set the backbone of the model')
    parser.add_argument('--show_param', action='store_true')  # show the number of parameter

    # Training schedule
    parser.add_argument('--sub', default='base', help='Job name')
    parser.add_argument('--clip_gradient', default=0, type=float, help='Max gradient')
    parser.add_argument('--weight_decay', default=0.0005, type=float)
    # The two flags below are ``store_false``: passing them turns the feature OFF.
    parser.add_argument('--data_aug', action='store_false', help='Data augmentation, only random crop')
    parser.add_argument('--multi', action='store_false', help='Multi-scale training')
    parser.add_argument('--gpus', default='0', type=str, help='Set the gpu devices')
    parser.add_argument('--strategy', default='adam_base', help='Training strategy, see base/strategy.py')
    parser.add_argument('--batch', default=10, type=int, help='Batch Size for Training')

    parser.add_argument('--training_stage', default=2, type=int, help='Training stage: 1=physical only, 2=physical+network, 3=freeze physical')

    # Data setting
    parser.add_argument('--size', default=384, type=int, help='Input size')
    parser.add_argument('--trset', default='DUTS-TR', help='Set the training set')
    parser.add_argument('--vals', default='all', help='Set the testing sets')
    parser.add_argument('--data_path', default='/home/v2-4080s/ouyang/OOD/kong/data/亮度/', help='Dataset path')
    parser.add_argument('--save_path', default='./result/our/HKU/meanetA更改/', help='Save path')
    parser.add_argument('--weight_path', default='./weight/HKU/meanetA更改/', help='Weight path')

    # Testing
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--weight', default='', help='Loading weight file')
    parser.add_argument('--save', action='store_true', help='Whether save result')
    parser.add_argument('--test_batch', default=6, type=int, help='Batch Size for Testing')
    parser.add_argument('--debug', action='store_true')  # Test model before training

    # Use for SALOD dataset
    parser.add_argument('--train_split', default=10000, type=int, help='Use for SALOD dataset')

    # Construct loss by loss_factory. More details in base/loss.py.
    parser.add_argument('--loss', default='bce,iou', type=str, help='Losses for networks')
    parser.add_argument('--lw', default='1,1', type=str, help='Weights for losses')

    # Customized arguments
    ### Base arguments with customized values
    parser.set_defaults(**model_default_config['base'])

    ### Customized arguments
    for flag, options in model_default_config['customized'].items():
        # Copy before adding ``dest`` so the model module's dict is not mutated.
        parser.add_argument(flag, **dict(options, dest=flag[2:]))

    config = vars(parser.parse_args())

    # Training sets whose validation sets are fixed regardless of --vals.
    fixed_vals = {
        'SALOD': ['SALOD'],
        'simple': ['tough', 'normal'],
    }
    # Training sets for which --vals=all expands to a known benchmark list.
    all_vals = {
        'DUTS-TR': ['PASCAL-S', 'ECSSD', 'HKU-IS', 'DUTS-TE', 'DUT-OMRON'],
        'COD-TR': ['COD-TE', 'NC4K', 'CAMO-TE'],
    }

    trset = config['trset']
    if trset in fixed_vals:
        config['vals'] = list(fixed_vals[trset])
    elif config['vals'] == 'all' and trset in all_vals:
        config['vals'] = list(all_vals[trset])
    else:
        config['vals'] = config['vals'].split(',')

    # Nest outputs per model / backbone / job so runs do not overwrite each other.
    for key in ('save_path', 'weight_path'):
        nested = os.path.join(config[key], config['model_name'], config['backbone'], config['sub'])
        check_path(nested)
        config[key] = nested

    return config

base/data.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import os
2+
import torch
3+
import random
4+
import numpy as np
5+
from PIL import Image
6+
import torch.utils.data as data
7+
import torchvision.transforms as transforms
8+
9+
# Per-channel offset subtracted from RGB images after they are scaled to
# [0, 1]; reshaped to (1, 1, 3) so it broadcasts over the trailing channel axis.
# NOTE(review): values look like the usual ImageNet statistics - confirm.
mean = np.array([0.485, 0.456, 0.406]).reshape([1, 1, 3])
10+
# Per-channel divisor applied to RGB images after the channel offset has been
# subtracted; reshaped to (1, 1, 3) so it broadcasts over the channel axis.
# NOTE(review): values look like the usual ImageNet statistics - confirm.
std = np.array([0.229, 0.224, 0.225]).reshape([1, 1, 3])
11+
12+
def _list_image_pairs(dataset_root):
    """Return sorted image paths and sorted mask paths under ``dataset_root``.

    Images are ``.jpg``/``.png`` files in ``<dataset_root>/images``; masks are
    ``.png`` files in ``<dataset_root>/segmentations``.
    """
    image_root = os.path.join(dataset_root, 'images')
    gt_root = os.path.join(dataset_root, 'segmentations')

    images = sorted(os.path.join(image_root, f) for f in os.listdir(image_root) if f.endswith(('.jpg', '.png')))
    gts = sorted(os.path.join(gt_root, f) for f in os.listdir(gt_root) if f.endswith('.png'))
    return images, gts


def get_image_list(name, config, phase):
    """Collect image paths and ground-truth mask paths for a dataset.

    Args:
        name: Dataset name or, in the ``'test'`` phase, an absolute dataset
            directory. ``'all'`` is a placeholder and yields empty lists.
        config: Mapping providing ``'data_path'``. It is not read when
            ``name`` is ``'all'`` or an absolute path in the test phase.
        phase: ``'train'`` or ``'test'``; for ``'SALOD'`` it selects the
            ``SALOD/<phase>.txt`` list file.

    Returns:
        tuple[list[str], list[str]]: ``(images, gts)``. For directory-based
        datasets the two lists are sorted independently, so pairing relies on
        images and masks sharing file stems.

    Raises:
        OSError: If a required list file or directory does not exist.
    """
    images = []
    gts = []

    if name == 'all':
        print("Skipping 'all' dataset as it does not have a corresponding directory.")
        return images, gts

    if name in ('simple', 'tough', 'normal'):
        train_split = 10000

        print('Objectness shifting experiment.')
        # Objectness
        list_file = 'clean_list.txt'
        # Read everything up front so the handle is closed deterministically.
        with open(os.path.join(config['data_path'], 'SALOD/{}'.format(list_file)), 'r') as f:
            lines = f.readlines()

        if name == 'simple':
            img_list = lines[-train_split:]
        elif name == 'normal':
            img_list = lines[train_split:-train_split]
        else:
            img_list = lines[:train_split]

        # Each line is "<image id> <extra columns...>"; keep only the id.
        img_list = [line.split(' ')[0] for line in img_list]

        images = [os.path.join(config['data_path'], 'SALOD/images', line.strip() + '.jpg') for line in img_list]
        gts = [os.path.join(config['data_path'], 'SALOD/mask', line.strip() + '.png') for line in img_list]

    # Benchmark
    elif name == 'SALOD':
        with open(os.path.join(config['data_path'], 'SALOD/{}.txt'.format(phase)), 'r') as f:
            img_list = f.readlines()

        images = [os.path.join(config['data_path'], name, 'images', line.strip() + '.jpg') for line in img_list]
        gts = [os.path.join(config['data_path'], name, 'mask', line.strip() + '.png') for line in img_list]

    elif phase == 'test' and os.path.isabs(name):
        # An absolute path is used as the dataset directory itself.
        images, gts = _list_image_pairs(name)

    else:
        print("Name:", name)
        print("Image root:", os.path.join(config['data_path'], name, 'images'))
        images, gts = _list_image_pairs(os.path.join(config['data_path'], name))

    return images, gts
65+
66+
def get_loader(config):
    """Create the shuffled training ``DataLoader`` described by ``config``.

    Reads ``config['trset']`` for the training set name and
    ``config['batch']`` for the batch size. Incomplete final batches are
    dropped, and loading uses 12 worker processes with pinned memory.
    """
    train_set = Train_Dataset(config['trset'], config)
    loader_options = {
        'batch_size': config['batch'],
        'shuffle': True,
        'num_workers': 12,
        'pin_memory': True,
        'drop_last': True,
    }
    return data.DataLoader(dataset=train_set, **loader_options)
75+
76+
77+
78+
class Train_Dataset(data.Dataset):
    """Map-style training dataset yielding ``(image, mask)`` tensor pairs."""

    def __init__(self, name, config):
        """Resolve the file lists of training set ``name`` via ``get_image_list``."""
        self.config = config
        image_paths, gt_paths = get_image_list(name, config, 'train')
        self.images = image_paths
        self.gts = gt_paths
        self.size = len(image_paths)

    def __getitem__(self, index):
        """Load sample ``index`` as a normalised image and a binary mask.

        Returns:
            tuple: ``(image, gt)`` float tensors. ``image`` is channel-first
            after normalisation with the module-level ``mean``/``std``;
            ``gt`` has a leading singleton axis and holds 1.0 where the
            resized mask value is above 128, otherwise 0.0.
        """
        rgb = Image.open(self.images[index]).convert('RGB')
        mask = Image.open(self.gts[index]).convert('L')

        side = self.config['size']
        target = (side, side)
        rgb = rgb.resize(target)
        mask = mask.resize(target)

        pixels = np.array(rgb).astype(np.float32)
        mask_values = np.array(mask)

        # Scale to [0, 1], normalise per channel, then move channels first.
        pixels = (((pixels / 255.) - mean) / std).transpose((2, 0, 1))
        binary = (mask_values > 128).astype(np.float32)
        binary = np.expand_dims(binary, axis=0)

        return torch.from_numpy(pixels).float(), torch.from_numpy(binary).float()

    def __len__(self):
        """Return the number of training images."""
        return self.size
105+
106+
class Test_Dataset:
    """Per-sample access to one evaluation set.

    Attributes are filled from ``get_image_list(name, config, 'test')``:
    ``images`` and ``gts`` are the path lists and ``size`` is the number of
    images.
    """

    def __init__(self, name, config=None):
        """Resolve the image and mask paths for evaluation set ``name``.

        Args:
            name: Dataset name, or an absolute dataset directory.
            config: Mapping with at least ``'size'`` (needed by
                ``load_data``); ``'data_path'`` is also needed unless
                ``name`` is an absolute path.
        """
        self.config = config
        self.images, self.gts = get_image_list(name, config, 'test')
        self.size = len(self.images)

    def load_data(self, index):
        """Load sample ``index``.

        Returns:
            tuple: ``(image, gt, name)``. ``image`` is a float tensor with a
            leading batch axis of 1 followed by channel-first pixels,
            normalised with the module-level ``mean``/``std``. ``gt`` is a
            float32 NumPy array holding 1.0 where the resized mask value is
            above 128, otherwise 0.0. ``name`` is the image file name up to
            its first dot.
        """
        target = (self.config['size'], self.config['size'])

        image = Image.open(self.images[index]).convert('RGB')
        image = image.resize(target)
        image = np.array(image).astype(np.float32)

        # The mask is decoded once; the original decoded it twice and
        # discarded the first result.
        gt = Image.open(self.gts[index]).convert('L')
        gt = gt.resize(target)
        gt = np.array(gt).astype(np.float32)

        # basename() copes with the platform's path separator, unlike
        # splitting on '/'.
        name = os.path.basename(self.images[index]).split('.')[0]

        image = ((image / 255.) - mean) / std
        image = image.transpose((2, 0, 1))
        image = torch.tensor(np.expand_dims(image, 0)).float()
        gt = (gt > 128).astype(np.float32)
        return image, gt, name
128+
129+
def test_data():
    """Smoke-test ``Test_Dataset`` on the ``SOD`` set under ``../dataset``.

    Loads every sample through ``Test_Dataset.load_data`` and prints the
    shape of the stacked images, the shape of the stacked masks and the
    number of names. ``Test_Dataset`` has no ``load_all_data`` method, so the
    samples are gathered here instead.
    """
    config = {'orig_size': True, 'size': 288, 'data_path': '../dataset'}
    dataset = 'SOD'
    data_loader = Test_Dataset(dataset, config)

    samples = [data_loader.load_data(i) for i in range(data_loader.size)]
    if not samples:
        # torch.cat / np.stack reject empty sequences; report and stop.
        print('No samples found for dataset {}.'.format(dataset))
        return

    # Every image already carries a leading batch axis of 1, so concatenate
    # along it; masks are plain 2-D arrays and get a new leading axis.
    imgs = torch.cat([sample[0] for sample in samples], dim=0)
    gts = np.stack([sample[1] for sample in samples])
    names = [sample[2] for sample in samples]
    print(imgs.shape, gts.shape, len(names))


if __name__ == "__main__":
    test_data()

0 commit comments

Comments
 (0)