Skip to content

Commit 50d4940

Browse files
committed
some reformatting; progress logging; dataparallel check
1 parent 554b276 commit 50d4940

File tree

3 files changed

+70
-39
lines changed

3 files changed

+70
-39
lines changed

example/main.py

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,41 +14,69 @@
1414

1515

1616
def extra_args(parser):
17-
parser.add_argument('--num_eigenthings', default=5, type=int,
18-
help='number of eigenvals/vecs to compute')
19-
parser.add_argument('--batch_size', default=128, type=int,
20-
help='train set batch size')
21-
parser.add_argument('--eval_batch_size', default=16, type=int,
22-
help='test set batch size')
23-
parser.add_argument('--momentum', default=0.0, type=float,
24-
help='power iteration momentum term')
25-
parser.add_argument('--num_steps', default=20, type=int,
26-
help='number of power iter steps')
27-
parser.add_argument('--cuda', action='store_true',
28-
help='if true, use CUDA/GPUs')
17+
parser.add_argument(
18+
"--num_eigenthings",
19+
default=5,
20+
type=int,
21+
help="number of eigenvals/vecs to compute",
22+
)
23+
parser.add_argument(
24+
"--batch_size", default=128, type=int, help="train set batch size"
25+
)
26+
parser.add_argument(
27+
"--eval_batch_size", default=16, type=int, help="test set batch size"
28+
)
29+
parser.add_argument(
30+
"--momentum", default=0.0, type=float, help="power iteration momentum term"
31+
)
32+
parser.add_argument(
33+
"--num_steps", default=50, type=int, help="number of power iter steps"
34+
)
35+
parser.add_argument("--max_samples", default=2048, type=int)
36+
parser.add_argument("--cuda", action="store_true", help="if true, use CUDA/GPUs")
37+
parser.add_argument(
38+
"--full_dataset",
39+
action="store_true",
40+
help="if true,\
41+
loop over all batches in set for each gradient step",
42+
)
43+
parser.add_argument("--fname", default="", type=str)
44+
parser.add_argument("--mode", type=str, choices=["power_iter", "lanczos"])
2945

3046

3147
def main(args):
32-
trainloader, testloader = build_dataset('cifar10',
33-
dataroot=args.dataroot,
34-
batch_size=args.batch_size,
35-
eval_batch_size=args.eval_batch_size,
36-
num_workers=2)
37-
model = build_model('ResNet18', num_classes=10)
48+
trainloader, testloader = build_dataset(
49+
"cifar10",
50+
dataroot=args.dataroot,
51+
batch_size=args.batch_size,
52+
eval_batch_size=args.eval_batch_size,
53+
num_workers=2,
54+
)
55+
if args.fname:
56+
print("Loading model from %s" % args.fname)
57+
model = torch.load(args.fname, map_location="cpu").cuda()
58+
else:
59+
model = build_model("ResNet18", num_classes=10)
3860
criterion = torch.nn.CrossEntropyLoss()
39-
eigenvals, eigenvecs = compute_hessian_eigenthings(model, testloader,
40-
criterion,
41-
args.num_eigenthings,
42-
args.num_steps,
43-
momentum=args.momentum,
44-
use_gpu=args.cuda)
61+
eigenvals, eigenvecs = compute_hessian_eigenthings(
62+
model,
63+
testloader,
64+
criterion,
65+
args.num_eigenthings,
66+
mode=args.mode,
67+
# power_iter_steps=args.num_steps,
68+
max_samples=args.max_samples,
69+
# momentum=args.momentum,
70+
full_dataset=args.full_dataset,
71+
use_gpu=args.cuda,
72+
)
4573
print("Eigenvecs:")
4674
print(eigenvecs)
4775
print("Eigenvals:")
4876
print(eigenvals)
49-
track.metric(iteration=0, eigenvals=eigenvals)
77+
# track.metric(iteration=0, eigenvals=eigenvals)
5078

5179

52-
if __name__ == '__main__':
80+
if __name__ == "__main__":
5381
skeletor.supply_args(extra_args)
5482
skeletor.execute(main)

hessian_eigenthings/__init__.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
""" Top-level module for hessian eigenvec computation """
2-
from hessian_eigenthings.power_iter import power_iteration,\
3-
deflated_power_iteration
2+
from hessian_eigenthings.power_iter import power_iteration, deflated_power_iteration
43
from hessian_eigenthings.lanczos import lanczos
5-
from hessian_eigenthings.hvp_operator import HVPOperator,\
6-
compute_hessian_eigenthings
4+
from hessian_eigenthings.hvp_operator import HVPOperator, compute_hessian_eigenthings
75

86
__all__ = [
9-
'power_iteration',
10-
'deflated_power_iteration',
11-
'lanczos',
12-
'HVPOperator',
13-
'compute_hessian_eigenthings'
7+
"power_iteration",
8+
"deflated_power_iteration",
9+
"lanczos",
10+
"HVPOperator",
11+
"compute_hessian_eigenthings",
1412
]
1513

16-
name = 'hessian_eigenthings'
14+
name = "hessian_eigenthings"

hessian_eigenthings/power_iter.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy as np
66
import torch
77

8+
from hessian_eigenthings.utils import log, progress_bar
9+
810

911
class Operator:
1012
"""
@@ -59,7 +61,9 @@ def deflated_power_iteration(
5961
def _deflate(x, val, vec):
6062
return val * vec.dot(x) * vec
6163

62-
for _ in range(num_eigenthings):
64+
log("beginning deflated power iteration")
65+
for i in range(num_eigenthings):
66+
log("computing eigenvalue/vector %d of %d" % (i + 1, num_eigenthings))
6367
eigenval, eigenvec = power_iteration(
6468
current_op,
6569
power_iter_steps,
@@ -68,6 +72,7 @@ def _deflate(x, val, vec):
6872
use_gpu=use_gpu,
6973
init_vec=prev_vec,
7074
)
75+
log("eigenvalue %d: %.4f" % (i + 1, eigenval))
7176

7277
def _new_op_fn(x, op=current_op, val=eigenval, vec=eigenvec):
7378
return op.apply(x) - _deflate(x, val, vec)
@@ -111,14 +116,14 @@ def power_iteration(
111116

112117
prev_lambda = 0.0
113118
prev_vec = torch.zeros_like(vec)
114-
for _ in range(steps):
119+
for i in range(steps):
115120
new_vec = operator.apply(vec) - momentum * prev_vec
116121
prev_vec = vec / (torch.norm(vec) + 1e-6)
117-
118122
lambda_estimate = vec.dot(new_vec).item()
119123
diff = lambda_estimate - prev_lambda
120124
vec = new_vec.detach() / torch.norm(new_vec)
121125
error = np.abs(diff / lambda_estimate)
126+
progress_bar(i, steps, "power iter error: %.4f" % error)
122127
if error < error_threshold:
123128
return lambda_estimate, vec
124129
prev_lambda = lambda_estimate

0 commit comments

Comments
 (0)