Skip to content

Commit 4f191f7

Browse files
【PPSCI Doc No.12、13、14、15、16、17】ppsci.arch.Arch (#752)
* [Add] arch examples * [Change] examples * [Change] examples * [Change] register_input_transform * [Change] data with rand --------- Co-authored-by: HydrogenSulfate <490868991@qq.com>
1 parent d2c2ff1 commit 4f191f7

File tree

1 file changed

+100
-2
lines changed

1 file changed

+100
-2
lines changed

ppsci/arch/base.py

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,23 @@ def concat_to_tensor(
7272
7373
Returns:
7474
Tuple[paddle.Tensor, ...]: Concatenated tensor.
75+
76+
Examples:
77+
>>> import paddle
78+
>>> import ppsci
79+
>>> model = ppsci.arch.Arch()
80+
>>> # fetch one tensor
81+
>>> out = model.concat_to_tensor({'x':paddle.rand([64, 64, 1])}, ('x',))
82+
>>> print(out.dtype, out.shape)
83+
paddle.float32 [64, 64, 1]
84+
>>> # fetch more tensors
85+
>>> out = model.concat_to_tensor(
86+
... {'x1':paddle.rand([64, 64, 1]), 'x2':paddle.rand([64, 64, 1])},
87+
... ('x1', 'x2'),
88+
... axis=2)
89+
>>> print(out.dtype, out.shape)
90+
paddle.float32 [64, 64, 2]
91+
7592
"""
7693
if len(keys) == 1:
7794
return data_dict[keys[0]]
@@ -90,6 +107,23 @@ def split_to_dict(
90107
91108
Returns:
92109
Dict[str, paddle.Tensor]: Dict contains tensor.
110+
111+
Examples:
112+
>>> import paddle
113+
>>> import ppsci
114+
>>> model = ppsci.arch.Arch()
115+
>>> # split one tensor
116+
>>> out = model.split_to_dict(paddle.rand([64, 64, 1]), ('x',))
117+
>>> for k, v in out.items():
118+
... print(f"{k} {v.dtype} {v.shape}")
119+
x paddle.float32 [64, 64, 1]
120+
>>> # split more tensors
121+
>>> out = model.split_to_dict(paddle.rand([64, 64, 2]), ('x1', 'x2'), axis=2)
122+
>>> for k, v in out.items():
123+
... print(f"{k} {v.dtype} {v.shape}")
124+
x1 paddle.float32 [64, 64, 1]
125+
x2 paddle.float32 [64, 64, 1]
126+
93127
"""
94128
if len(keys) == 1:
95129
return {keys[0]: data_tensor}
@@ -105,6 +139,27 @@ def register_input_transform(
105139
Args:
106140
transform (Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
107141
Input transform of network, receive a single tensor dict and return a single tensor dict.
142+
143+
Examples:
144+
>>> import ppsci
145+
>>> def transform_in(in_):
146+
... x = in_["x"]
147+
... # transform input
148+
... x_ = 2.0 * x
149+
... input_trans = {"2x": x_}
150+
... return input_trans
151+
>>> # `MLP` inherits from `Arch`
152+
>>> model = ppsci.arch.MLP(
153+
... input_keys=("2x",),
154+
... output_keys=("y",),
155+
... num_layers=5,
156+
... hidden_size=32)
157+
>>> model.register_input_transform(transform_in)
158+
>>> out = model({"x":paddle.rand([64, 64, 1])})
159+
>>> for k, v in out.items():
160+
... print(f"{k} {v.dtype} {v.shape}")
161+
y paddle.float32 [64, 64, 1]
162+
108163
"""
109164
self._input_transform = transform
110165

@@ -121,18 +176,61 @@ def register_output_transform(
121176
transform (Callable[[Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
122177
Output transform of network, receive two single tensor dict(raw input
123178
and raw output) and return a single tensor dict(transformed output).
179+
180+
Examples:
181+
>>> import ppsci
182+
>>> def transform_out(in_, out):
183+
... x = in_["x"]
184+
... y = out["y"]
185+
... u = 2.0 * x * y
186+
... output_trans = {"u": u}
187+
... return output_trans
188+
>>> # `MLP` inherits from `Arch`
189+
>>> model = ppsci.arch.MLP(
190+
... input_keys=("x",),
191+
... output_keys=("y",),
192+
... num_layers=5,
193+
... hidden_size=32)
194+
>>> model.register_output_transform(transform_out)
195+
>>> out = model({"x":paddle.rand([64, 64, 1])})
196+
>>> for k, v in out.items():
197+
... print(f"{k} {v.dtype} {v.shape}")
198+
u paddle.float32 [64, 64, 1]
199+
124200
"""
125201
self._output_transform = transform
126202

127203
def freeze(self):
128-
"""Freeze all parameters."""
204+
"""Freeze all parameters.
205+
206+
Examples:
207+
>>> import ppsci
208+
>>> model = ppsci.arch.Arch()
209+
>>> # freeze all parameters and make model `eval`
210+
>>> model.freeze()
211+
>>> assert not model.training
212+
>>> for p in model.parameters():
213+
... assert p.stop_gradient
214+
215+
"""
129216
for param in self.parameters():
130217
param.stop_gradient = True
131218

132219
self.eval()
133220

134221
def unfreeze(self):
135-
"""Unfreeze all parameters."""
222+
"""Unfreeze all parameters.
223+
224+
Examples:
225+
>>> import ppsci
226+
>>> model = ppsci.arch.Arch()
227+
>>> # unfreeze all parameters and make model `train`
228+
>>> model.unfreeze()
229+
>>> assert model.training
230+
>>> for p in model.parameters():
231+
... assert not p.stop_gradient
232+
233+
"""
136234
for param in self.parameters():
137235
param.stop_gradient = False
138236

0 commit comments

Comments
 (0)