Skip to content

Conversation

@0x45f
Copy link
Contributor

@0x45f 0x45f commented Mar 30, 2022

PR types

Bug fixes

PR changes

Others

Describe

修复rnn在控制流中使用时,rnn调用parent_block.var(name)报错的问题。
问题描述:在静态图下网络参数都会在block0中,rnn的静态图逻辑中会调用parent_block.var去父block中找param,但是如果父block不是block0则会报错。用户提供了如下的动转静代码,改用_find_var_recursive后动转静可以正常导出:

```python
import paddle
from paddle import nn
import paddle.tensor as tensor
import paddle.nn.functional as F
import paddle.nn.initializer as I


class LSTMCell(nn.RNNCellBase):
    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 activation="tanh",
                 weight_ih_attr=None,
                 weight_hh_attr=None,
                 bias_ih_attr=None,
                 bias_hh_attr=None,
                 guass_mean=0.0,
                 guass_std=0.02,
                 name=None):
        super(LSTMCell, self).__init__()
        self.weight_ih = self.create_parameter(
            (4 * hidden_size, input_size),
            weight_ih_attr,
            default_initializer=I.Normal(guass_mean, guass_std))
        self.weight_hh = self.create_parameter(
            (4 * hidden_size, hidden_size),
            weight_hh_attr,
            default_initializer=I.Normal(guass_mean, guass_std))
        self.bias_ih = self.create_parameter(
            (4 * hidden_size, ),
            bias_ih_attr,
            is_bias=True,
            default_initializer=I.Normal(guass_mean, guass_std))
        self.bias_hh = self.create_parameter(
            (4 * hidden_size, ),
            bias_hh_attr,
            is_bias=True,
            default_initializer=I.Normal(guass_mean, guass_std))
        self.hidden_size = hidden_size
        self.input_size = input_size
        self.gate_activation = F.sigmoid
        activation_dict = {
            'tanh': paddle.tanh,
            'relu': F.relu,
            'gelu': F.gelu
        }
        if activation not in activation_dict:
            raise RuntimeError(f"{activation} is not supported in LSTMCell")
        self.activation = activation_dict[activation]

    def forward(self, inputs, states=None):
        # import pdb; pdb.set_trace()
        if states is None:
            states = self.get_initial_states(inputs, self.state_shape)
        prev_h, prev_c = states
        gates = paddle.matmul(inputs, self.weight_ih, transpose_y=True)
        if self.bias_ih is not None:
            gates = gates + self.bias_ih
        gates += paddle.matmul(prev_h, self.weight_hh, transpose_y=True)
        if self.bias_hh is not None:
            gates = gates + self.bias_hh
        chunked_gates = paddle.split(gates, num_or_sections=4, axis=-1)
        i = self.gate_activation(chunked_gates[0])
        f = self.gate_activation(chunked_gates[1])
        g = self.activation(chunked_gates[2])
        o = self.gate_activation(chunked_gates[3])
        c = f * prev_c + i * g
        h = o * self.activation(c)
        return h, (h, c)

    @property
    def state_shape(self):
        r"""
        The `state_shape` of LSTMCell is a tuple with two shapes:
        `((hidden_size, ), (hidden_size,))`. (-1 for batch size would be
        automatically inserted into shape). These two shapes correspond
        to :math:`h_{t-1}` and :math:`c_{t-1}` separately.
        """
        return ((self.hidden_size, ), (self.hidden_size, ))

    def extra_repr(self):
        return '{input_size}, {hidden_size}'.format(**self.__dict__)


class Decoder(nn.Layer):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = LSTMCell(input_size, hidden_size)
        self.rnn = nn.RNN(self.cell)
        self.sos = paddle.ones(shape=[1, 1, 2], dtype='float32')
        self.init_states = (paddle.zeros(shape=[1, 1, 4], dtype='float32'),
                            paddle.zeros(shape=[1, 1, 4], dtype='float32'))
        self.idx = paddle.zeros(shape=[1], dtype='int32')
        # self.idx = 0
        self.states = self.rnn(self.sos, self.init_states)
        self.step = 0

    def forward(self, inputs, hidden=None, cell=None):
        if hidden is None:
            states = self.states
        else:
            states = (hidden, cell)
        # import pdb; pdb.set_trace()
        if self.idx < 1:
            outs, states = self.rnn(inputs, states)
            self.idx += 1
        final_states = states
        return outs, final_states

    def export(self):
        static_model = paddle.jit.to_static(
            self,
            input_spec=[
                paddle.static.InputSpec(shape=[1, 1, 2], dtype='float32'),
                paddle.static.InputSpec(shape=[1, 1, 4], dtype='float32'),
                paddle.static.InputSpec(shape=[1, 1, 4], dtype='float32')
            ])
        return static_model


model = Decoder(2, 4)
model.eval()
static_model = model.export()
paddle.jit.save(static_model, "test_model")
```
Copy link
Contributor

@Aurelius84 Aurelius84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link

@iclementine iclementine left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@guoshengCS guoshengCS left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Aurelius84 Aurelius84 merged commit a54ec5a into PaddlePaddle:develop Mar 31, 2022
@0x45f 0x45f deleted the dy2st_cond_with_rnn branch March 31, 2022 08:24
@0x45f 0x45f changed the title Fix parent_block.var(name) error in static mode Fix parent_block.var(name) error in static mode for RNN Apr 24, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

4 participants