
Commit 88275da

elistevens authored and apaszke committed
CUDA documentation tweaks (pytorch#858)
1 parent bd7a5ad commit 88275da

File tree

4 files changed: +41 −25 lines changed


docs/source/notes/cuda.rst

Lines changed: 5 additions & 3 deletions
@@ -71,9 +71,11 @@ Use nn.DataParallel instead of multiprocessing
 
 Most use cases involving batched input and multiple GPUs should default to using
 :class:`~torch.nn.DataParallel` to utilize more than one GPU. Even with the GIL,
-a single python process can saturate multiple GPUs, though at very large numbers
-of GPUs (8+) utilization might drop. Test your use case before investing the
-time to develop something more complicated.
+a single python process can saturate multiple GPUs.
+
+As of version 0.1.9, large numbers of GPUs (8+) might not be fully utilized.
+However, this is a known issue that is under active development. As always,
+test your use case.
 
 There are significant caveats to using CUDA models with
 :mod:`~torch.multiprocessing`; unless care is taken to meet the data handling
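The DataParallel guidance above can be sketched in runnable form (a minimal sketch: the `nn.Linear` sizes and batch shape are made up, and the wrapper only engages when more than one GPU is visible):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in network; any nn.Module works the same way.
net = nn.Linear(10, 5)

# A single Python process drives all visible GPUs; DataParallel splits
# the batch along dim 0 and replicates the module on each device.
if torch.cuda.device_count() > 1:
    net = nn.DataParallel(net).cuda()

batch = torch.randn(8, 10)
if next(net.parameters()).is_cuda:
    batch = batch.cuda()

output = net(batch)
print(output.shape)  # torch.Size([8, 5])
```

On a single-GPU or CPU-only machine the module runs unwrapped, so the same script covers both cases.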
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+
+Serialization semantics
+=======================
+
+Best practices
+--------------
+
+.. _recommend-saving-models:
+
+Recommended approach for saving a model
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+There are two main approaches for serializing and restoring a model.
+
+The first (recommended) saves and loads only the model parameters::
+
+    torch.save(the_model.state_dict(), PATH)
+
+Then later::
+
+    the_model = TheModelClass(*args, **kwargs)
+    the_model.load_state_dict(torch.load(PATH))
+
+The second saves and loads the entire model::
+
+    torch.save(the_model, PATH)
+
+Then later::
+
+    the_model = torch.load(PATH)
+
+However in this case, the serialized data is bound to the specific classes
+and the exact directory structure used, so it can break in various ways when
+used in other projects, or after some serious refactors.
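The recommended state_dict round trip from the new note, expanded into a self-contained sketch (`TheModelClass` and the layer sizes are placeholders, and the save path is a temp file):

```python
import os
import tempfile

import torch
import torch.nn as nn

class TheModelClass(nn.Module):
    # Placeholder model; the real class can be any nn.Module.
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

PATH = os.path.join(tempfile.mkdtemp(), "model.pth")

the_model = TheModelClass()
torch.save(the_model.state_dict(), PATH)   # parameters only; no class pickled

restored = TheModelClass()                 # class definition must exist here
restored.load_state_dict(torch.load(PATH))

x = torch.randn(3, 4)
assert torch.equal(the_model(x), restored(x))
```

Because only tensors are written, loading works anywhere the class definition is importable, regardless of how the source tree has been reorganized.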

torch/nn/parallel/data_parallel.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ class DataParallel(Module):
         >>> output = net(input_var)
     """
 
+    # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
     def __init__(self, module, device_ids=None, output_device=None):
         super(DataParallel, self).__init__()
         if device_ids is None:

torch/serialization.py

Lines changed: 1 addition & 22 deletions
@@ -103,28 +103,7 @@ def storage_to_tensor_type(storage):
 def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
     """Saves an object to a disk file.
 
-    There are two main approaches for serializing and restoring a model.
-
-    The first (recommended) saves and loads only the model parameters::
-
-        torch.save(the_model.state_dict(), PATH)
-
-    Then later::
-
-        the_model = TheModelClass(*args, **kwargs)
-        the_model.load_state_dict(torch.load(PATH))
-
-    The second saves and loads the entire model::
-
-        torch.save(the_model, PATH)
-
-    Then later::
-
-        the_model = torch.load(PATH))
-
-    The second relies on both the shape of the model, as well as the class
-    definition. This results in it being more fragile, since if the source code
-    of the class changes, the model will no longer load.
+    See also: :ref:`recommend-saving-models`
 
     Args:
         obj: saved object
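The `save(obj, f, ...)` signature above takes any picklable object and either a path or a file-like object for `f`; a short round-trip sketch (the in-memory buffer and checkpoint payload are illustrative):

```python
import io

import torch

# torch.save writes to a path or any file-like object; an in-memory
# buffer keeps the example self-contained.
buffer = io.BytesIO()
checkpoint = {"step": 3, "weights": torch.ones(2, 2)}
torch.save(checkpoint, buffer)

buffer.seek(0)
restored = torch.load(buffer)
print(restored["step"])  # 3
```

Plain containers of tensors and primitives, like the dict here, round-trip cleanly; arbitrary custom classes fall back on pickle and inherit its fragility.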
