1+ """A module containing objects to instantiate various neural network components."""
12import re
23from functools import partial
3- from ast import literal_eval as eval
4+ from ast import literal_eval as _eval
45
56import numpy as np
67
78from ..optimizers import OptimizerBase , SGD , AdaGrad , RMSProp , Adam
8- from ..activations import ActivationBase , Affine , ReLU , Tanh , Sigmoid , LeakyReLU
9+ from ..activations import (
10+ ELU ,
11+ GELU ,
12+ SELU ,
13+ ReLU ,
14+ Tanh ,
15+ Affine ,
16+ Sigmoid ,
17+ Identity ,
18+ SoftPlus ,
19+ LeakyReLU ,
20+ Exponential ,
21+ HardSigmoid ,
22+ ActivationBase ,
23+ )
924from ..schedulers import (
1025 SchedulerBase ,
1126 ConstantScheduler ,
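The renamed alias also stops shadowing the builtin `eval`: `ast.literal_eval` parses only Python literals, so a malformed or hostile spec-string value cannot execute code. A standard-library-only sketch of the difference:

    from ast import literal_eval as _eval

    _eval("0.01")              # -> 0.01 (a float literal)
    _eval("(0.9, 0.999)")      # -> (0.9, 0.999) (a tuple literal)
    _eval("__import__('os')")  # raises ValueError: calls are not literals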
@@ -26,18 +41,20 @@
 class ActivationInitializer(object):
     def __init__(self, param=None):
         """
-        A class for initializing activation functions. Valid inputs are:
-            (a) __str__ representations of `ActivationBase` instances
-            (b) `ActivationBase` instances
+        A class for initializing activation functions. Valid `param` values
+        are:
+            (a) the ``__str__`` representation of an `ActivationBase` instance
+            (b) an `ActivationBase` instance
 
         If `param` is `None`, return the identity function: f(X) = X
         """
         self.param = param
 
     def __call__(self):
+        """Initialize the activation function"""
         param = self.param
         if param is None:
-            act = Affine(slope=1, intercept=0)
+            act = Identity()
         elif isinstance(param, ActivationBase):
             act = param
         elif isinstance(param, str):
@@ -47,13 +64,24 @@ def __call__(self):
         return act
 
     def init_from_str(self, act_str):
+        """Initialize an activation function from the `param` string"""
         act_str = act_str.lower()
         if act_str == "relu":
             act_fn = ReLU()
         elif act_str == "tanh":
             act_fn = Tanh()
+        elif act_str == "selu":
+            act_fn = SELU()
         elif act_str == "sigmoid":
             act_fn = Sigmoid()
+        elif act_str == "identity":
+            act_fn = Identity()
+        elif act_str == "hardsigmoid":
+            act_fn = HardSigmoid()
+        elif act_str == "softplus":
+            act_fn = SoftPlus()
+        elif act_str == "exponential":
+            act_fn = Exponential()
         elif "affine" in act_str:
             r = r"affine\(slope=(.*), intercept=(.*)\)"
             slope, intercept = re.match(r, act_str).groups()
@@ -62,6 +90,14 @@ def init_from_str(self, act_str):
             r = r"leaky relu\(alpha=(.*)\)"
             alpha = re.match(r, act_str).groups()[0]
             act_fn = LeakyReLU(float(alpha))
+        elif "gelu" in act_str:
+            r = r"gelu\(approximate=(.*)\)"
+            approx = re.match(r, act_str).groups()[0] == "true"
+            act_fn = GELU(approximate=approx)
+        elif "elu" in act_str:
+            r = r"elu\(alpha=(.*)\)"
+            alpha = re.match(r, act_str).groups()[0]
+            act_fn = ELU(alpha=float(alpha))
         else:
             raise ValueError("Unknown activation: {}".format(act_str))
         return act_fn
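As a usage sketch (assuming the relative imports above resolve, i.e. this module lives inside its parent package), a spec string is simply an activation's ``__str__`` output, and parameterized activations round-trip through the regexes above:

    act = ActivationInitializer("leaky relu(alpha=0.3)")()  # -> LeakyReLU(alpha=0.3)
    tanh = ActivationInitializer("tanh")()                   # exact-match branch
    ident = ActivationInitializer(None)()                    # param=None -> Identity()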
@@ -70,7 +106,8 @@ def init_from_str(self, act_str):
 class SchedulerInitializer(object):
     def __init__(self, param=None, lr=None):
         """
-        A class for initializing learning rate schedulers. Valid inputs are:
+        A class for initializing learning rate schedulers. Valid `param`
+        values are:
             (a) __str__ representations of `SchedulerBase` instances
             (b) `SchedulerBase` instances
             (c) Parameter dicts (e.g., as produced via the `summary` method in
@@ -86,6 +123,7 @@ def __init__(self, param=None, lr=None):
         self.param = param
 
     def __call__(self):
+        """Initialize the scheduler"""
         param = self.param
         if param is None:
             scheduler = ConstantScheduler(self.lr)
@@ -98,9 +136,10 @@ def __call__(self):
         return scheduler
 
     def init_from_str(self):
+        """Initialize a scheduler from the `param` string"""
         r = r"([a-zA-Z]*)=([^,)]*)"
         sch_str = self.param.lower()
-        kwargs = dict([(i, eval(j)) for (i, j) in re.findall(r, sch_str)])
+        kwargs = {i: _eval(j) for i, j in re.findall(r, sch_str)}
 
         if "constant" in sch_str:
             scheduler = ConstantScheduler(**kwargs)
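To make the spec-string parsing concrete: the regex extracts `name=value` pairs, and `literal_eval` converts each value into the corresponding Python literal. A standalone sketch using only the stdlib pieces shown above:

    import re
    from ast import literal_eval as _eval

    sch_str = "ConstantScheduler(lr=0.01)".lower()
    kwargs = {i: _eval(j) for i, j in re.findall(r"([a-zA-Z]*)=([^,)]*)", sch_str)}
    # kwargs == {'lr': 0.01}; "constant" in sch_str, so ConstantScheduler(**kwargs)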
@@ -115,6 +154,7 @@ def init_from_str(self):
         return scheduler
 
     def init_from_dict(self):
+        """Initialize a scheduler from the `param` dictionary"""
         S = self.param
         sc = S["hyperparameters"] if "hyperparameters" in S else None
 
@@ -136,7 +176,7 @@ def init_from_dict(self):
 class OptimizerInitializer(object):
     def __init__(self, param=None):
         """
-        A class for initializing optimizers. Valid inputs are:
+        A class for initializing optimizers. Valid `param` values are:
             (a) __str__ representations of `OptimizerBase` instances
             (b) `OptimizerBase` instances
             (c) Parameter dicts (e.g., as produced via the `summary` method in
@@ -147,6 +187,7 @@ def __init__(self, param=None):
         self.param = param
 
     def __call__(self):
+        """Initialize the optimizer"""
         param = self.param
         if param is None:
             opt = SGD()
@@ -159,9 +200,10 @@ def __call__(self):
         return opt
 
     def init_from_str(self):
+        """Initialize an optimizer from the `param` string"""
         r = r"([a-zA-Z]*)=([^,)]*)"
         opt_str = self.param.lower()
-        kwargs = dict([(i, eval(j)) for (i, j) in re.findall(r, opt_str)])
+        kwargs = {i: _eval(j) for i, j in re.findall(r, opt_str)}
         if "sgd" in opt_str:
             optimizer = SGD(**kwargs)
         elif "adagrad" in opt_str:
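An optimizer spec round-trips the same way; a sketch assuming the `Adam` import above, its usual `lr` keyword, and an `"adam"` branch among the elided `elif` cases:

    opt = OptimizerInitializer("Adam(lr=0.001)")()  # kwargs = {'lr': 0.001}
    default = OptimizerInitializer()()              # param=None -> SGD()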
@@ -175,12 +217,13 @@ def init_from_str(self):
         return optimizer
 
     def init_from_dict(self):
-        O = self.param
-        cc = O["cache"] if "cache" in O else None
-        op = O["hyperparameters"] if "hyperparameters" in O else None
+        """Initialize an optimizer from the `param` dictionary"""
+        D = self.param
+        cc = D["cache"] if "cache" in D else None
+        op = D["hyperparameters"] if "hyperparameters" in D else None
 
         if op is None:
-            raise ValueError("Must have `hyperparemeters` key: {}".format(O))
+            raise ValueError("`param` dictionary has no `hyperparameters` key")
 
         if op and op["id"] == "SGD":
             optimizer = SGD()
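The expected dict shape follows from the keys read above; a hypothetical spec (the exact `hyperparameters` contents come from each optimizer's summary output, so treat the fields as illustrative):

    spec = {"cache": {}, "hyperparameters": {"id": "SGD", "lr": 0.01}}
    opt = OptimizerInitializer(spec)()  # dispatches on spec["hyperparameters"]["id"]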
@@ -237,6 +280,7 @@ def __init__(self, act_fn_str, mode="glorot_uniform"):
             self._fn = partial(truncated_normal, mean=0, std=1)
 
     def __call__(self, weight_shape):
+        """Initialize weights according to the specified strategy"""
         if "glorot" in self.mode:
             gain = self._calc_glorot_gain()
             W = self._fn(weight_shape, gain)
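For reference, a minimal sketch of the draw a Glorot-uniform strategy conventionally performs. The module's real `glorot_uniform` helper and `_calc_glorot_gain` are defined elsewhere in the file; the bound below is the standard Glorot & Bengio (2010) form, shown only for intuition (2D case; conv layers also fold the receptive-field size into the fans):

    def glorot_uniform_sketch(weight_shape, gain=1.0):
        # W ~ Uniform(-b, b) with b = gain * sqrt(6 / (fan_in + fan_out))
        fan_in, fan_out = weight_shape[0], weight_shape[-1]
        b = gain * np.sqrt(6.0 / (fan_in + fan_out))
        return np.random.uniform(-b, b, size=weight_shape)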