Skip to content

Commit 149e384

Browse files
authored
make eigenvalue float instead of tensor (#9)
1 parent 7f3ec46 commit 149e384

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

example/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def extra_args(parser):
2424
help='power iteration momentum term')
2525
parser.add_argument('--num_steps', default=20, type=int,
2626
help='number of power iter steps')
27+
parser.add_argument('--cuda', action='store_true',
28+
help='if true, use CUDA/GPUs')
2729

2830

2931
def main(args):
@@ -38,7 +40,8 @@ def main(args):
3840
criterion,
3941
args.num_eigenthings,
4042
args.num_steps,
41-
momentum=args.momentum)
43+
momentum=args.momentum,
44+
use_gpu=args.cuda)
4245
print("Eigenvecs:")
4346
print(eigenvecs)
4447
print("Eigenvals:")

hessian_eigenthings/power_iter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _deflate(x, val, vec):
6363
def _new_op_fn(x, op=current_op, val=eigenval, vec=eigenvec):
6464
return op.apply(x) - _deflate(x, val, vec)
6565
current_op = LambdaOperator(_new_op_fn, operator.size)
66-
eigenvals.append(eigenval.item())
66+
eigenvals.append(eigenval)
6767
eigenvecs.append(eigenvec.cpu())
6868

6969
return eigenvals, eigenvecs
@@ -88,7 +88,7 @@ def power_iteration(operator, steps=20, error_threshold=1e-4,
8888
new_vec = operator.apply(vec) - momentum * prev_vec
8989
prev_vec = vec / torch.norm(vec)
9090

91-
lambda_estimate = vec.dot(new_vec)
91+
lambda_estimate = vec.dot(new_vec).item()
9292
diff = lambda_estimate - prev_lambda
9393
vec = new_vec.detach() / torch.norm(new_vec)
9494
error = np.abs(diff / lambda_estimate)

0 commit comments

Comments
 (0)