- Notifications
You must be signed in to change notification settings - Fork 5.9k
feature/analysis node representation #10522
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b2e2b82 52e6b20 1957680 1daaffd d9be253 8db1325 98265c7 17469ef e5383c0 897326c 47f0f5b 4088304 b086f53 094c0b1 db4a076 File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1,2 @@ | ||
| cc_library(dot SRCS dot.cc) | ||
| cc_library(analysis SRCS dot.cc node.cc node.h) | ||
| cc_test(test_node SRCS node_tester.cc DEPS analysis) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,23 @@ | ||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
| | ||
| namespace paddle { | ||
| namespace inference { | ||
| namespace analysis { | ||
| | ||
| enum class Device { CPU, GPU }; | ||
| | ||
| } // namespace analysis | ||
| } // namespace inference | ||
| } // namespace paddle |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
| | ||
| #include "paddle/fluid/inference/analysis/dot.h" | ||
| | ||
| #include <gtest/gtest.h> | ||
| #include <memory> | ||
| #include "paddle/fluid/inference/analysis/data_flow_graph.h" | ||
| | ||
| namespace paddle { | ||
| namespace inference { | ||
| namespace analysis { | ||
| | ||
| class DotTester : public ::testing::Test { | ||
| protected: | ||
| void SetUp() override { | ||
| std::vector<Dot::Attr> attrs({{"title", "hello"}}); | ||
| dot.reset(new Dot(attrs)); | ||
| dot->AddNode("a", {Dot::Attr{"shape", "box"}, Dot::Attr("color", "blue")}); | ||
| dot->AddNode("b", {}); | ||
| dot->AddNode("c", {}); | ||
| dot->AddEdge("a", "b", {}); | ||
| dot->AddEdge("b", "c", {}); | ||
| dot->AddEdge("a", "c", {}); | ||
| } | ||
| | ||
| std::unique_ptr<Dot> dot; | ||
| }; | ||
| | ||
| TEST_F(DotTester, Build) { | ||
| auto codes = dot->Build(); | ||
| // Output the DOT language code, the generated codes are too long to compare | ||
| // the string. | ||
| // | ||
| // The output is | ||
| // | ||
| // digraph G { | ||
| // title="hello" | ||
| // node_1 | ||
| // node_2 | ||
| // node_0[label="a" shape="box" color="blue"] | ||
| // node_0->node_1 | ||
| // node_1->node_2 | ||
| // node_0->node_2 | ||
| // } // end G | ||
| LOG(INFO) << '\n' << codes; | ||
| } | ||
| | ||
| } // namespace analysis | ||
| } // namespace inference | ||
| } // namespace paddle |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
| | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
| | ||
| #pragma once | ||
| | ||
| #include <string> | ||
| #include <unordered_map> | ||
| #include <vector> | ||
| | ||
| #include "paddle/fluid/platform/enforce.h" | ||
| | ||
| namespace paddle { | ||
| namespace inference { | ||
| namespace analysis { | ||
| | ||
| template <typename IteratorT> | ||
| class iterator_range { | ||
| IteratorT begin_, end_; | ||
| | ||
| public: | ||
| template <typename Container> | ||
| explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {} | ||
| | ||
| iterator_range(const IteratorT &begin, const IteratorT &end) | ||
| : begin_(begin), end_(end) {} | ||
| | ||
| const IteratorT &begin() const { return begin_; } | ||
| const IteratorT &end() const { return end_; } | ||
| }; | ||
| | ||
| /* | ||
| * An registry helper class, with its records keeps the order they registers. | ||
| */ | ||
| template <typename T> | ||
| class OrderedRegistry { | ||
| public: | ||
| T *Register(const std::string &name, T *x) { | ||
| PADDLE_ENFORCE(!dic_.count(name)); | ||
| dic_[name] = data_.size(); | ||
| data_.emplace_back(std::unique_ptr<T>(x)); | ||
| return data_.back().get(); | ||
| } | ||
| | ||
| T *Lookup(const std::string &name) { | ||
| auto it = dic_.find(name); | ||
| if (it == dic_.end()) return nullptr; | ||
| return data_[it->second].get(); | ||
| } | ||
| | ||
| protected: | ||
| std::unordered_map<std::string, int> dic_; | ||
| std::vector<std::unique_ptr<T>> data_; | ||
| }; | ||
| | ||
| } // namespace analysis | ||
| } // namespace inference | ||
| } // namespace paddle | ||
| | ||
| #define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \ | ||
| \ | ||
| type__(const type__ &) = delete; \ | ||
| \ | ||
| void operator=(const type__ &) = delete; | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
| | ||
| #include "paddle/fluid/inference/analysis/node.h" | ||
| #include "glog/logging.h" | ||
| #include "paddle/fluid/platform/enforce.h" | ||
| | ||
| namespace paddle { | ||
| namespace inference { | ||
| namespace analysis { | ||
| | ||
| std::vector<Dot::Attr> Value::dot_attrs() const { | ||
| return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"), | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what are these "style" "shape" thing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are DOT language's attributes, to control the visualization for graph debug. Here, different There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. | ||
| Dot::Attr("shape", "box"), | ||
| Dot::Attr("fillcolor", "red")}); | ||
| } | ||
| | ||
| std::vector<Dot::Attr> Function::dot_attrs() const { | ||
| return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"), | ||
| Dot::Attr("shape", "diamond"), | ||
| Dot::Attr("fillcolor", "yellow")}); | ||
| } | ||
| | ||
| Node *NodeMap::Create(Node::Type type) { | ||
| switch (type) { | ||
| case Node::Type::kFunction: | ||
| nodes_.emplace_back(new Function); | ||
| break; | ||
| case Node::Type::kValue: | ||
| nodes_.emplace_back(new Value); | ||
| break; | ||
| default: | ||
| PADDLE_THROW("Not supported node type."); | ||
| } | ||
| nodes_.back()->id_ = size() - 1; | ||
| return nodes_.back().get(); | ||
| } | ||
| | ||
| Node *NodeMap::GetMutable(size_t id) { | ||
| PADDLE_ENFORCE_GT(size(), id); | ||
| return nodes_[id].get(); | ||
| } | ||
| | ||
| const Node &NodeMap::Get(size_t id) const { | ||
| PADDLE_ENFORCE_GT(size(), id); | ||
| return *nodes_[id].get(); | ||
| } | ||
| | ||
| void NodeMap::Delete(size_t id) { | ||
| PADDLE_ENFORCE_LT(id, size()); | ||
| nodes_[id]->SetDeleted(); | ||
| } | ||
| | ||
| } // namespace analysis | ||
| } // namespace inference | ||
| } // namespace paddle | ||

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this be std::unique_ptr? if the ownership is transferred.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because the pointer will be used by others and a
shared_ptris a little heavy here, so the implementation of this function isdic_[name] = data_.size(); data_.emplace_back(std::unique_ptr<T>(x)); return data_.back().get();Here just a
Registerand aLookup, without aRemoveAPI, so the owner can't be transferred to others.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not suggesting a shared_ptr here.
Usually, unique_ptr in function def means the ownership of the pointer is tranferred to the function.