Closed
Description
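import torch
from numpy.linalg import svd  # assumed source of `svd`
from numpy.random import normal  # assumed source of `normal`

def _get_orthogonal_init_weights(weights):
    # Flatten a 4-D conv kernel (out, in, kH, kW) into a fan_out x fan_in matrix.
    fan_out = weights.size(0)
    fan_in = weights.size(1) * weights.size(2) * weights.size(3)
    # Thin SVD of a random Gaussian matrix; u has orthonormal columns, v orthonormal rows.
    u, _, v = svd(normal(0.0, 1.0, (fan_out, fan_in)), full_matrices=False)
    # Keep whichever factor has the shape (fan_out, fan_in).
    if u.shape == (fan_out, fan_in):
        return torch.Tensor(u.reshape(weights.size()))
    else:
        return torch.Tensor(v.reshape(weights.size()))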
Why is the above operation performed?
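Reading the code, the function draws a random Gaussian matrix shaped like the flattened kernel (fan_out × fan_in) and keeps whichever thin-SVD factor has that same shape. With `full_matrices=False`, `u` is fan_out × min(fan_out, fan_in) and `v` is min(fan_out, fan_in) × fan_in, so `u` matches when fan_out ≥ fan_in and `v` matches otherwise. The kept factor has orthonormal columns (fan_out > fan_in) or orthonormal rows (fan_out ≤ fan_in), i.e. the kernel starts out as a (semi-)orthogonal matrix. This is the usual "orthogonal initialization", commonly motivated by keeping signal and gradient norms from blowing up or vanishing across layers. The sketch below checks that property. It assumes `svd`/`normal` are NumPy's, swaps in `np.random.default_rng` for seeding, and introduces two hypothetical helpers, `orthogonal_init_weights` (takes a shape instead of a tensor) and `is_semi_orthogonal`. It also compares against PyTorch's built-in `torch.nn.init.orthogonal_`, which uses a QR decomposition rather than an SVD.

```python
import numpy as np
import torch


def orthogonal_init_weights(shape, seed=0):
    """Hypothetical re-implementation of the issue's function, taking a shape."""
    fan_out = shape[0]
    fan_in = int(np.prod(shape[1:]))
    rng = np.random.default_rng(seed)
    # Random Gaussian matrix with the flattened kernel's 2-D shape.
    a = rng.normal(0.0, 1.0, (fan_out, fan_in))
    # Thin SVD: u is (fan_out, k), v is (k, fan_in), with k = min(fan_out, fan_in).
    u, _, v = np.linalg.svd(a, full_matrices=False)
    # Exactly one factor has shape (fan_out, fan_in); both do if the matrix is square.
    q = u if u.shape == (fan_out, fan_in) else v
    return torch.from_numpy(q.reshape(shape)).float()


def is_semi_orthogonal(w, atol=1e-4):
    """True if the flattened weight has orthonormal rows or orthonormal columns."""
    m = w.reshape(w.shape[0], -1).double()
    rows, cols = m.shape
    # Gram matrix of the shorter side should be the identity.
    gram = m @ m.T if rows <= cols else m.T @ m
    eye = torch.eye(min(rows, cols), dtype=torch.float64)
    return torch.allclose(gram, eye, atol=atol)


# Wide kernel (32 x 576): rows are orthonormal.
w_wide = orthogonal_init_weights((32, 64, 3, 3))
# Tall kernel (64 x 25): columns are orthonormal.
w_tall = orthogonal_init_weights((64, 1, 5, 5))
# PyTorch's built-in orthogonal initializer on the same wide shape.
w_builtin = torch.nn.init.orthogonal_(torch.empty(32, 64, 3, 3))

print(is_semi_orthogonal(w_wide), is_semi_orthogonal(w_tall), is_semi_orthogonal(w_builtin))
# A plain Gaussian kernel does not have this property.
print(is_semi_orthogonal(torch.randn(32, 64, 3, 3)))
```

In practice, calling `torch.nn.init.orthogonal_` directly on the layer's weight gives the same kind of (semi-)orthogonal starting point without the NumPy round trip.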