@@ -56,6 +56,10 @@ pub enum DiffActivity {
56
56
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
57
57
/// with it. Drop the code which updates the original input/output for maximum performance.
58
58
DualOnly ,
59
+ /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
60
+ /// with it. Drop the code which updates the original input/output for maximum performance.
61
+ /// It expects the shadow argument to be `width` times larger than the original input/output.
62
+ DualvOnly ,
59
63
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
60
64
Duplicated ,
61
65
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
@@ -139,6 +143,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
139
143
activity == DiffActivity :: Dual
140
144
|| activity == DiffActivity :: Dualv
141
145
|| activity == DiffActivity :: DualOnly
146
+ || activity == DiffActivity :: DualvOnly
142
147
|| activity == DiffActivity :: Const
143
148
}
144
149
DiffMode :: Reverse => {
@@ -161,7 +166,7 @@ pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
161
166
if matches ! ( activity, Const ) {
162
167
return true ;
163
168
}
164
- if matches ! ( activity, Dual | DualOnly | Dualv ) {
169
+ if matches ! ( activity, Dual | DualOnly | Dualv | DualvOnly ) {
165
170
return true ;
166
171
}
167
172
// FIXME(ZuseZ4) We should make this more robust to also
@@ -178,7 +183,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
178
183
DiffMode :: Error => false ,
179
184
DiffMode :: Source => false ,
180
185
DiffMode :: Forward => {
181
- matches ! ( activity, Dual | DualOnly | Dualv | Const )
186
+ matches ! ( activity, Dual | DualOnly | Dualv | DualvOnly | Const )
182
187
}
183
188
DiffMode :: Reverse => {
184
189
matches ! ( activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const )
@@ -196,6 +201,7 @@ impl Display for DiffActivity {
196
201
DiffActivity :: Dual => write ! ( f, "Dual" ) ,
197
202
DiffActivity :: Dualv => write ! ( f, "Dualv" ) ,
198
203
DiffActivity :: DualOnly => write ! ( f, "DualOnly" ) ,
204
+ DiffActivity :: DualvOnly => write ! ( f, "DualvOnly" ) ,
199
205
DiffActivity :: Duplicated => write ! ( f, "Duplicated" ) ,
200
206
DiffActivity :: DuplicatedOnly => write ! ( f, "DuplicatedOnly" ) ,
201
207
DiffActivity :: FakeActivitySize => write ! ( f, "FakeActivitySize" ) ,
@@ -228,6 +234,7 @@ impl FromStr for DiffActivity {
228
234
"Dual" => Ok ( DiffActivity :: Dual ) ,
229
235
"Dualv" => Ok ( DiffActivity :: Dualv ) ,
230
236
"DualOnly" => Ok ( DiffActivity :: DualOnly ) ,
237
+ "DualvOnly" => Ok ( DiffActivity :: DualvOnly ) ,
231
238
"Duplicated" => Ok ( DiffActivity :: Duplicated ) ,
232
239
"DuplicatedOnly" => Ok ( DiffActivity :: DuplicatedOnly ) ,
233
240
_ => Err ( ( ) ) ,
0 commit comments