Skip to content

Commit 92d65fe

Browse files
mrry authored and tensorflower-gardener committed
Add ConfigProto.isolate_session_state option for the distributed runtime.
Setting this option to true when creating a session ensures that no stateful resources (variables, queues, iterators, etc.) will be visible to any other session running on the same server, and that those resources will be deleted when the session is closed.

The default behavior is preserved: all `tf.Variable` objects are shared, and most other resources are shared when their `shared_name` attr is non-empty.

This change augments the semantics of the WorkerService.CreateWorkerSession RPC: if the `server_def` in the request is empty, the worker should use its default ClusterSpec.

Note that clusters created using ClusterSpec propagation always have isolated session state, and are unaffected by this change.

PiperOrigin-RevId: 177173545
1 parent b262375 commit 92d65fe

File tree

15 files changed

+374
-31
lines changed

15 files changed

+374
-31
lines changed

tensorflow/core/common_runtime/device.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class Device : public DeviceBase {
131131
OpSegment* op_segment() { return &op_seg_; }
132132

133133
// Returns the resource manager associated w/ this device.
134-
ResourceMgr* resource_manager() { return rmgr_; }
134+
virtual ResourceMgr* resource_manager() { return rmgr_; }
135135

136136
// Summarizes the status of this Device, for debugging.
137137
string DebugString() const { return ProtoDebugString(device_attributes_); }

tensorflow/core/common_runtime/renamed_device.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ namespace tensorflow {
2121
/* static */
2222
Device* RenamedDevice::NewRenamedDevice(const string& new_base,
2323
Device* underlying,
24-
bool owns_underlying) {
24+
bool owns_underlying,
25+
bool isolate_session_state) {
2526
DeviceNameUtils::ParsedName parsed_name;
2627
CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name));
2728
DeviceNameUtils::ParsedName underlying_parsed_name =
@@ -35,15 +36,17 @@ Device* RenamedDevice::NewRenamedDevice(const string& new_base,
3536
parsed_name.id);
3637
DeviceAttributes attributes(underlying->attributes());
3738
attributes.set_name(name);
38-
return new RenamedDevice(underlying, attributes, owns_underlying);
39+
return new RenamedDevice(underlying, attributes, owns_underlying,
40+
isolate_session_state);
3941
}
4042

4143
RenamedDevice::RenamedDevice(Device* underlying,
4244
const DeviceAttributes& attributes,
43-
bool owns_underlying)
45+
bool owns_underlying, bool isolate_session_state)
4446
: Device(underlying->env(), attributes),
4547
underlying_(underlying),
46-
owns_underlying_(owns_underlying) {}
48+
owns_underlying_(owns_underlying),
49+
isolate_session_state_(isolate_session_state) {}
4750

4851
RenamedDevice::~RenamedDevice() {
4952
if (owns_underlying_) {

tensorflow/core/common_runtime/renamed_device.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ namespace tensorflow {
2929
class RenamedDevice : public Device {
3030
public:
3131
static Device* NewRenamedDevice(const string& new_base, Device* underlying,
32-
bool owns_underlying);
32+
bool owns_underlying,
33+
bool isolate_session_state);
34+
3335
~RenamedDevice() override;
3436

3537
// Below are virtual methods defined on DeviceBase
@@ -113,11 +115,21 @@ class RenamedDevice : public Device {
113115
return underlying_->FillContextMap(graph, device_context_map);
114116
}
115117

118+
// Returns the resource manager associated w/ this device.
//
// When `isolate_session_state_` is true, this returns the manager of this
// RenamedDevice itself (via Device::resource_manager()). Otherwise the call
// is forwarded to `underlying_`, so every non-isolating wrapper of the same
// underlying device returns the same manager.
119+
ResourceMgr* resource_manager() override {
120+
if (isolate_session_state_) {
121+
return Device::resource_manager();
122+
} else {
123+
return underlying_->resource_manager();
124+
}
125+
}
126+
116127
private:
117128
RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
118-
bool owns_underlying);
129+
bool owns_underlying, bool isolate_session_state);
119130
Device* const underlying_;
120131
const bool owns_underlying_;
132+
const bool isolate_session_state_;
121133
};
122134

123135
} // namespace tensorflow

tensorflow/core/distributed_runtime/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ cc_library(
140140
hdrs = ["session_mgr.h"],
141141
deps = [
142142
":graph_mgr",
143+
":worker_cache_wrapper",
143144
":worker_session",
144145
"//tensorflow/core:core_cpu_internal",
145146
"//tensorflow/core:lib",
@@ -263,6 +264,17 @@ cc_library(
263264
],
264265
)
265266

267+
cc_library(
268+
name = "worker_cache_wrapper",
269+
hdrs = ["worker_cache_wrapper.h"],
270+
deps = [
271+
":worker_cache",
272+
":worker_interface",
273+
"//tensorflow/core:lib",
274+
"//tensorflow/core:protos_all_cc",
275+
],
276+
)
277+
266278
cc_library(
267279
name = "remote_device",
268280
srcs = ["remote_device.cc"],

tensorflow/core/distributed_runtime/master_session.cc

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,7 +1049,10 @@ Status MasterSession::Create(GraphDef* graph_def,
10491049
TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
10501050
graph_def, execution_options, &execution_state_));
10511051
}
1052-
if (options.cluster_def != nullptr) {
1052+
// TODO(b/36574172): Remove these conditions when ClusterSpec
1053+
// propagation is supported in all servers.
1054+
if (options.cluster_def != nullptr ||
1055+
session_opts_.config.isolate_session_state()) {
10531056
should_delete_worker_sessions_ = true;
10541057
return CreateWorkerSessions(options);
10551058
}
@@ -1058,10 +1061,9 @@ Status MasterSession::Create(GraphDef* graph_def,
10581061

10591062
Status MasterSession::CreateWorkerSessions(
10601063
const WorkerCacheFactoryOptions& options) {
1061-
CHECK(worker_cache_) << "CreateWorkerSessions should be called only with "
1062-
<< "dynamic cluster membership.";
10631064
std::vector<string> worker_names;
1064-
worker_cache_->ListWorkers(&worker_names);
1065+
WorkerCacheInterface* worker_cache = get_worker_cache();
1066+
worker_cache->ListWorkers(&worker_names);
10651067

10661068
struct WorkerGroup {
10671069
// The worker name. (Not owned.)
@@ -1079,10 +1081,10 @@ Status MasterSession::CreateWorkerSessions(
10791081
std::vector<WorkerGroup> workers(worker_names.size());
10801082

10811083
// Release the workers.
1082-
auto cleanup = gtl::MakeCleanup([this, &workers] {
1084+
auto cleanup = gtl::MakeCleanup([this, &workers, worker_cache] {
10831085
for (auto&& worker_group : workers) {
10841086
if (worker_group.worker != nullptr) {
1085-
worker_cache_->ReleaseWorker(*worker_group.name, worker_group.worker);
1087+
worker_cache->ReleaseWorker(*worker_group.name, worker_group.worker);
10861088
}
10871089
}
10881090
});
@@ -1091,11 +1093,19 @@ Status MasterSession::CreateWorkerSessions(
10911093
// Create all the workers & kick off the computations.
10921094
for (size_t i = 0; i < worker_names.size(); ++i) {
10931095
workers[i].name = &worker_names[i];
1094-
workers[i].worker = worker_cache_->CreateWorker(worker_names[i]);
1096+
workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
10951097
workers[i].request.set_session_handle(handle_);
1096-
*workers[i].request.mutable_server_def()->mutable_cluster() =
1097-
*options.cluster_def;
1098-
workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
1098+
if (options.cluster_def) {
1099+
*workers[i].request.mutable_server_def()->mutable_cluster() =
1100+
*options.cluster_def;
1101+
workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
1102+
// Session state is always isolated when ClusterSpec propagation
1103+
// is in use.
1104+
workers[i].request.set_isolate_session_state(true);
1105+
} else {
1106+
workers[i].request.set_isolate_session_state(
1107+
session_opts_.config.isolate_session_state());
1108+
}
10991109

11001110
DeviceNameUtils::ParsedName name;
11011111
if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
@@ -1162,7 +1172,7 @@ Status MasterSession::DeleteWorkerSessions() {
11621172
// Create a worker for each name and set the session handle on its request.
11631173
for (size_t i = 0; i < worker_names.size(); ++i) {
11641174
workers[i].name = &worker_names[i];
1165-
workers[i].worker = worker_cache_->CreateWorker(worker_names[i]);
1175+
workers[i].worker = worker_cache->CreateWorker(worker_names[i]);
11661176
workers[i].request.set_session_handle(handle_);
11671177
}
11681178

tensorflow/core/distributed_runtime/session_mgr.cc

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ limitations under the License.
2020
#include "tensorflow/core/common_runtime/device_mgr.h"
2121
#include "tensorflow/core/common_runtime/renamed_device.h"
2222
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
23+
#include "tensorflow/core/distributed_runtime/worker_cache_wrapper.h"
2324
#include "tensorflow/core/lib/strings/strcat.h"
25+
#include "tensorflow/core/protobuf/cluster.pb.h"
26+
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
2427

2528
namespace tensorflow {
2629

@@ -29,7 +32,10 @@ SessionMgr::SessionMgr(
2932
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
3033
WorkerCacheFactory worker_cache_factory)
3134
: worker_env_(worker_env),
32-
legacy_session_("", default_worker_name, std::move(default_worker_cache),
35+
default_worker_cache_(std::move(default_worker_cache)),
36+
legacy_session_("", default_worker_name,
37+
std::unique_ptr<WorkerCacheInterface>(
38+
new WorkerCacheWrapper(default_worker_cache_.get())),
3339
std::unique_ptr<DeviceMgr>(worker_env->device_mgr),
3440
std::unique_ptr<GraphMgr>(
3541
new GraphMgr(worker_env, worker_env->device_mgr))),
@@ -41,7 +47,8 @@ string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
4147
}
4248

4349
Status SessionMgr::CreateSession(const string& session,
44-
const ServerDef& server_def) {
50+
const ServerDef& server_def,
51+
bool isolate_session_state) {
4552
mutex_lock l(mu_);
4653
if (session.empty()) {
4754
return errors::InvalidArgument("Session must be non-empty.");
@@ -50,12 +57,18 @@ Status SessionMgr::CreateSession(const string& session,
5057
const string worker_name = WorkerNameFromServerDef(server_def);
5158

5259
WorkerCacheInterface* worker_cache = nullptr;
53-
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
60+
if (server_def.cluster().job().empty()) {
61+
worker_cache = new WorkerCacheWrapper(default_worker_cache_.get());
62+
} else {
63+
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
64+
}
5465

66+
CHECK(!worker_env_->local_devices.empty())
67+
<< "The WorkerEnv must have at least one device in `local_devices`.";
5568
std::vector<Device*> renamed_devices;
5669
for (Device* d : worker_env_->local_devices) {
57-
renamed_devices.push_back(
58-
RenamedDevice::NewRenamedDevice(worker_name, d, false));
70+
renamed_devices.push_back(RenamedDevice::NewRenamedDevice(
71+
worker_name, d, false, isolate_session_state));
5972
}
6073
std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(renamed_devices));
6174

tensorflow/core/distributed_runtime/session_mgr.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ class SessionMgr {
4545
~SessionMgr() {}
4646

4747
// Allocates state for a new session.
48-
Status CreateSession(const string& session, const ServerDef& server_def);
48+
Status CreateSession(const string& session, const ServerDef& server_def,
49+
bool isolate_session_state);
4950

5051
// Locates the worker session for a given session handle
5152
WorkerSession* WorkerSessionForSession(const string& session);
@@ -71,6 +72,7 @@ class SessionMgr {
7172
// legacy_session_ is deleted. Further, we must ensure that WorkerSession's
7273
// device_mgr is deleted after WorkerSession's graph_mgr.
7374

75+
std::unique_ptr<WorkerCacheInterface> default_worker_cache_;
7476
WorkerSession legacy_session_;
7577

7678
const WorkerCacheFactory worker_cache_factory_;

tensorflow/core/distributed_runtime/session_mgr_test.cc

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,36 @@ limitations under the License.
2222

2323
namespace tensorflow {
2424

25+
// Minimal Device subclass for these tests. Instances can only be created
// through MakeCPU() because the constructor is private.
class FakeDevice : public Device {
26+
private:
27+
// Forwards `device_attributes` to Device; nullptr is passed as Device's
// first constructor argument.
explicit FakeDevice(const DeviceAttributes& device_attributes)
28+
: Device(nullptr, device_attributes) {}
29+
30+
public:
31+
// Not supported by the fake: always returns an Unimplemented error.
Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }
32+
33+
// The fake has no allocator: always returns nullptr.
Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
34+
35+
// Builds a FakeDevice whose attributes carry `name` and the device type
// "FakeCPU", returned as an owning pointer to the Device base class.
static std::unique_ptr<Device> MakeCPU(const string& name) {
36+
DeviceAttributes device_attributes;
37+
device_attributes.set_name(name);
38+
device_attributes.set_device_type(DeviceType("FakeCPU").type());
39+
return std::unique_ptr<Device>(new FakeDevice(device_attributes));
40+
}
41+
};
42+
2543
class SessionMgrTest : public ::testing::Test {
2644
protected:
2745
SessionMgrTest()
28-
: mgr_(&env_, "/job:mnist/replica:0/task:0",
29-
std::unique_ptr<WorkerCacheInterface>(),
30-
factory_),
31-
legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {}
46+
: device_(FakeDevice::MakeCPU(
47+
"/job:mnist/replica:0/task:0/device:fakecpu:0")),
48+
mgr_(&env_, "/job:mnist/replica:0/task:0",
49+
std::unique_ptr<WorkerCacheInterface>(), factory_),
50+
legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {
51+
env_.local_devices = {device_.get()};
52+
}
3253

54+
std::unique_ptr<Device> device_;
3355
WorkerEnv env_;
3456
SessionMgr::WorkerCacheFactory factory_ =
3557
[](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
@@ -42,14 +64,48 @@ class SessionMgrTest : public ::testing::Test {
4264

4365
TEST_F(SessionMgrTest, CreateSessionSimple) {
4466
ServerDef server_def;
67+
server_def.set_job_name("worker");
68+
server_def.set_task_index(3);
69+
4570
string session_handle = "test_session_handle";
46-
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
71+
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def, true));
4772
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
4873
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
4974
EXPECT_NE(mgr_.LegacySession(), session);
5075
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
5176
}
5277

78+
// Checks how the `isolate_session_state` argument of CreateSession affects
// the resource manager exposed by each session's device: sessions created
// with `false` must report the same manager, while each session created with
// `true` must report a manager distinct from every other session's.
TEST_F(SessionMgrTest, CreateSessionIsolateSessionState) {
79+
ServerDef server_def;
80+
server_def.set_job_name("worker");
81+
server_def.set_task_index(3);
82+
83+
// Sessions 1 and 2: isolate_session_state == false.
TF_EXPECT_OK(mgr_.CreateSession("handle_1", server_def, false));
84+
WorkerSession* session_1 = mgr_.WorkerSessionForSession("handle_1");
85+
std::vector<Device*> devices_1 = session_1->device_mgr->ListDevices();
86+
EXPECT_EQ(1, devices_1.size());
87+
88+
TF_EXPECT_OK(mgr_.CreateSession("handle_2", server_def, false));
89+
WorkerSession* session_2 = mgr_.WorkerSessionForSession("handle_2");
90+
std::vector<Device*> devices_2 = session_2->device_mgr->ListDevices();
91+
EXPECT_EQ(1, devices_2.size());
92+
93+
// Sessions 3 and 4: isolate_session_state == true.
TF_EXPECT_OK(mgr_.CreateSession("handle_3", server_def, true));
94+
WorkerSession* session_3 = mgr_.WorkerSessionForSession("handle_3");
95+
std::vector<Device*> devices_3 = session_3->device_mgr->ListDevices();
96+
EXPECT_EQ(1, devices_3.size());
97+
98+
TF_EXPECT_OK(mgr_.CreateSession("handle_4", server_def, true));
99+
WorkerSession* session_4 = mgr_.WorkerSessionForSession("handle_4");
100+
std::vector<Device*> devices_4 = session_4->device_mgr->ListDevices();
101+
EXPECT_EQ(1, devices_4.size());
102+
103+
// The two non-isolated sessions share one resource manager; each isolated
// session's manager differs from the shared one and from the other isolated
// session's.
EXPECT_EQ(devices_1[0]->resource_manager(), devices_2[0]->resource_manager());
104+
EXPECT_NE(devices_1[0]->resource_manager(), devices_3[0]->resource_manager());
105+
EXPECT_NE(devices_1[0]->resource_manager(), devices_4[0]->resource_manager());
106+
EXPECT_NE(devices_3[0]->resource_manager(), devices_4[0]->resource_manager());
107+
}
108+
53109
TEST_F(SessionMgrTest, LegacySession) {
54110
ServerDef server_def;
55111
string session_handle = "";

tensorflow/core/distributed_runtime/worker.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ void Worker::CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
4444
CreateWorkerSessionResponse* response,
4545
StatusCallback done) {
4646
Status s = env_->session_mgr->CreateSession(request->session_handle(),
47-
request->server_def());
47+
request->server_def(),
48+
request->isolate_session_state());
4849
done(s);
4950
}
5051

0 commit comments

Comments
 (0)