Refine InferShape for recurrent_network_op #3124
Conversation
* The tensor only contains the shape and does not hold memory when inferring shape.
… rnn_infershape
  const std::vector<Link>& outlinks,
- const size_t seq_len) {
+ const size_t seq_len,
+ bool infer_shape) {
is_infer or infer_mode?
infer_shape_mode
  std::shared_ptr<Scope> linked_scope = scopes[step_id + offset];
  for (auto& attr : memories) {
-   auto mem = scope->CreateVariable(attr.pre_var)->GetMutable<Tensor>();
+   auto mem = scope->GetVariable(attr.pre_var)->GetMutable<Tensor>();
Need an enforce that `scope->GetVariable(xxx)` is not nullptr?
Done.
  for (size_t i = 0; i < seq_len_; i++) {
    if (i > 0) {
-     rnn::LinkMemories(step_scopes, arg_->memories, i, -1);
+     rnn::LinkMemories(step_scopes, arg_->memories, i, -1, true);
true /* infer_mode */
Done.
  step_scopes[0]->GetVariable(outlinks[i].external)->GetMutable<Tensor>();
  output->Resize(make_ddim(dims_vec));
  }
+ rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true);
Same as the top.
Done.
  // maybe remove following code after testing
  if (step_id > 0) {
-   rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
+   rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1, false);
Same as the top.
Done.
  net->AddOp(
-     OpRegistry::CreateOp("add_two", {"rnn/x", "rnn/s"}, {"rnn/h"}, {}));
+     OpRegistry::CreateOp("add_two", {"x@alias", "rnn/s"}, {"rnn/h"}, {}));
How about an inline function for adding the `@alias` suffix?
This unit test may move to Python, so I won't add an inline function for adding the `@alias` suffix.
  ->GetMutable<std::vector<std::shared_ptr<Scope>>>();
  for (int i = 1; i < 10; ++i) {
-   rnn::LinkMemories(*step_scopes, memories, i, -1);
+   rnn::LinkMemories(*step_scopes, memories, i, -1, true);
Same as the top: `true /* infer_mode */`.
Done.
In RecurrentOp, InferShape may be called or not called? When does each case apply?
  auto tensor = scope->CreateVariable("h")->GetMutable<Tensor>();
  float* data = tensor->mutable_data<float>(make_ddim({15, 20}), CPUPlace());
- for (int i = 0; i < 15 * 20; ++i) {
+ for (int j = 0; j < 15 * 20; ++j) {
data[i] ==> data[j] ?
Done
  void SegmentInputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                     const std::vector<Link>& inlinks,
-                    const size_t seq_len) {
+                    const size_t seq_len,
const size_t&
const size_t is ok.
  void ConcatOutputs(std::vector<std::shared_ptr<Scope>>& step_scopes,
                     const std::vector<Link>& outlinks,
-                    const size_t seq_len) {
+                    const size_t seq_len,
const size_t&
const size_t is ok.
  const std::vector<rnn::MemoryAttr>& memories,
- size_t step_id,
+ const size_t step_id,
  int offset) {
const size_t&
  }
- LinkBootMemoryGradients(step_scopes[0]);
+ LinkBootMemoryGradients(step_scopes[0], false);
  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_);
Add a comment for the `false` argument, e.g. `false /* infer_shape_mode */`?
  rnn::LinkMemories(step_scopes, memories, i, -1, false /*infer_shape_mode*/);
  }
  // check
  for (int i = 0; i < len - 1; ++i) {
Unify all the array indices from `int` to `size_t`.
LGTM
LGTM!
The tensor only contains the shape and does not hold memory when inferring shape.