@@ -72,6 +72,23 @@ def concat_to_tensor(
 
         Returns:
             Tuple[paddle.Tensor, ...]: Concatenated tensor.
+
+        Examples:
+            >>> import paddle
+            >>> import ppsci
+            >>> model = ppsci.arch.Arch()
+            >>> # fetch one tensor
+            >>> out = model.concat_to_tensor({'x': paddle.rand([64, 64, 1])}, ('x',))
+            >>> print(out.dtype, out.shape)
+            paddle.float32 [64, 64, 1]
+            >>> # fetch more tensors
+            >>> out = model.concat_to_tensor(
+            ...     {'x1': paddle.rand([64, 64, 1]), 'x2': paddle.rand([64, 64, 1])},
+            ...     ('x1', 'x2'),
+            ...     axis=2)
+            >>> print(out.dtype, out.shape)
+            paddle.float32 [64, 64, 2]
+
         """
         if len(keys) == 1:
             return data_dict[keys[0]]
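
Only the single-key shortcut of `concat_to_tensor` falls inside this hunk. A minimal sketch of what the multi-key branch presumably looks like, assuming the stock `paddle.concat` API; the body below is reconstructed from the doctest output, not the file's actual code:

    import paddle

    # Hypothetical reconstruction: gather the named tensors in key order
    # and join them along `axis`; the single-key case returns as-is.
    def concat_to_tensor(data_dict, keys, axis=-1):
        if len(keys) == 1:
            return data_dict[keys[0]]
        return paddle.concat([data_dict[key] for key in keys], axis)

    out = concat_to_tensor(
        {"x1": paddle.rand([64, 64, 1]), "x2": paddle.rand([64, 64, 1])},
        ("x1", "x2"),
        axis=2,
    )
    assert tuple(out.shape) == (64, 64, 2)  # matches the doctest above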
@@ -90,6 +107,23 @@ def split_to_dict(
 
         Returns:
             Dict[str, paddle.Tensor]: Dict contains tensor.
+
+        Examples:
+            >>> import paddle
+            >>> import ppsci
+            >>> model = ppsci.arch.Arch()
+            >>> # split one tensor
+            >>> out = model.split_to_dict(paddle.rand([64, 64, 1]), ('x',))
+            >>> for k, v in out.items():
+            ...     print(f"{k} {v.dtype} {v.shape}")
+            x paddle.float32 [64, 64, 1]
+            >>> # split more tensors
+            >>> out = model.split_to_dict(paddle.rand([64, 64, 2]), ('x1', 'x2'), axis=2)
+            >>> for k, v in out.items():
+            ...     print(f"{k} {v.dtype} {v.shape}")
+            x1 paddle.float32 [64, 64, 1]
+            x2 paddle.float32 [64, 64, 1]
+
         """
         if len(keys) == 1:
             return {keys[0]: data_tensor}
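
Likewise, only the single-key shortcut of `split_to_dict` is visible here. A hedged sketch of the multi-key branch, assuming equal-width sections via `paddle.split`, which is consistent with the [64, 64, 2] → two [64, 64, 1] doctest above:

    import paddle

    # Hypothetical reconstruction: split into len(keys) equal sections
    # along `axis` and zip them back up with the requested keys.
    def split_to_dict(data_tensor, keys, axis=-1):
        if len(keys) == 1:
            return {keys[0]: data_tensor}
        data = paddle.split(data_tensor, len(keys), axis=axis)
        return {key: data[i] for i, key in enumerate(keys)}

    out = split_to_dict(paddle.rand([64, 64, 2]), ("x1", "x2"), axis=2)
    assert tuple(out["x1"].shape) == (64, 64, 1)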
@@ -105,6 +139,28 @@ def register_input_transform(
         Args:
             transform (Callable[[Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
                 Input transform of network, receive a single tensor dict and return a single tensor dict.
+
+        Examples:
+            >>> import paddle
+            >>> import ppsci
+            >>> def transform_in(in_):
+            ...     x = in_["x"]
+            ...     # transform input
+            ...     x_ = 2.0 * x
+            ...     input_trans = {"2x": x_}
+            ...     return input_trans
+            >>> # `MLP` inherits from `Arch`
+            >>> model = ppsci.arch.MLP(
+            ...     input_keys=("2x",),
+            ...     output_keys=("y",),
+            ...     num_layers=5,
+            ...     hidden_size=32)
+            >>> model.register_input_transform(transform_in)
+            >>> out = model({"x": paddle.rand([64, 64, 1])})
+            >>> for k, v in out.items():
+            ...     print(f"{k} {v.dtype} {v.shape}")
+            y paddle.float32 [64, 64, 1]
+
         """
         self._input_transform = transform
 
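
`register_input_transform` only stores the callable; the point where it fires is outside this hunk. A self-contained sketch of the presumable call pattern, with the hypothetical `TinyArch` below standing in for the real base class:

    import paddle

    # Hypothetical miniature of the hook: the stored callable rewrites the
    # input dict before the network body consumes it, mirroring the
    # {"x": ...} -> {"2x": ...} doctest above.
    class TinyArch(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self._input_transform = None
            self.linear = paddle.nn.Linear(1, 1)

        def register_input_transform(self, transform):
            self._input_transform = transform

        def forward(self, x):
            if self._input_transform is not None:
                x = self._input_transform(x)
            return {"y": self.linear(x["2x"])}

    model = TinyArch()
    model.register_input_transform(lambda d: {"2x": 2.0 * d["x"]})
    out = model({"x": paddle.rand([64, 64, 1])})
    assert tuple(out["y"].shape) == (64, 64, 1)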
@@ -121,18 +177,62 @@ def register_output_transform(
             transform (Callable[[Dict[str, paddle.Tensor], Dict[str, paddle.Tensor]], Dict[str, paddle.Tensor]]):
                 Output transform of network, receive two single tensor dict(raw input
                 and raw output) and return a single tensor dict(transformed output).
+
+        Examples:
+            >>> import paddle
+            >>> import ppsci
+            >>> def transform_out(in_, out):
+            ...     x = in_["x"]
+            ...     y = out["y"]
+            ...     u = 2.0 * x * y
+            ...     output_trans = {"u": u}
+            ...     return output_trans
+            >>> # `MLP` inherits from `Arch`
+            >>> model = ppsci.arch.MLP(
+            ...     input_keys=("x",),
+            ...     output_keys=("y",),
+            ...     num_layers=5,
+            ...     hidden_size=32)
+            >>> model.register_output_transform(transform_out)
+            >>> out = model({"x": paddle.rand([64, 64, 1])})
+            >>> for k, v in out.items():
+            ...     print(f"{k} {v.dtype} {v.shape}")
+            u paddle.float32 [64, 64, 1]
+
         """
         self._output_transform = transform
 
     def freeze(self):
-        """Freeze all parameters."""
+        """Freeze all parameters.
+
+        Examples:
+            >>> import ppsci
+            >>> model = ppsci.arch.Arch()
+            >>> # freeze all parameters and make model `eval`
+            >>> model.freeze()
+            >>> assert not model.training
+            >>> for p in model.parameters():
+            ...     assert p.stop_gradient
+
+        """
         for param in self.parameters():
             param.stop_gradient = True
 
         self.eval()
 
     def unfreeze(self):
-        """Unfreeze all parameters."""
+        """Unfreeze all parameters.
+
+        Examples:
+            >>> import ppsci
+            >>> model = ppsci.arch.Arch()
+            >>> # unfreeze all parameters and make model `train`
+            >>> model.unfreeze()
+            >>> assert model.training
+            >>> for p in model.parameters():
+            ...     assert not p.stop_gradient
+
+        """
         for param in self.parameters():
             param.stop_gradient = False
 
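
Both methods reduce to per-parameter `stop_gradient` flips plus a mode switch. A standalone check of those mechanics on a plain `paddle.nn.Linear`, mirroring the doctests above (the unfreeze hunk ends before any `train()` call, so the mode flip back is shown here only for symmetry):

    import paddle

    layer = paddle.nn.Linear(4, 4)

    # what freeze() does: drop every parameter out of autograd, then eval()
    for p in layer.parameters():
        p.stop_gradient = True
    layer.eval()
    assert not layer.training

    # what unfreeze() does: re-enable gradients; train() restores the mode
    for p in layer.parameters():
        p.stop_gradient = False
    layer.train()
    assert layer.training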