11use super :: potentially_plural_count;
22use crate :: errors:: LifetimesOrBoundsMismatchOnTrait ;
3- use rustc_data_structures:: fx:: FxHashSet ;
3+ use hir:: def_id:: DefId ;
4+ use rustc_data_structures:: fx:: { FxHashMap , FxHashSet } ;
45use rustc_errors:: { pluralize, struct_span_err, Applicability , DiagnosticId , ErrorGuaranteed } ;
56use rustc_hir as hir;
67use rustc_hir:: def:: { DefKind , Res } ;
78use rustc_hir:: intravisit;
89use rustc_hir:: { GenericParamKind , ImplItemKind , TraitItemKind } ;
910use rustc_infer:: infer:: outlives:: env:: OutlivesEnvironment ;
11+ use rustc_infer:: infer:: type_variable:: { TypeVariableOrigin , TypeVariableOriginKind } ;
1012use rustc_infer:: infer:: { self , TyCtxtInferExt } ;
1113use rustc_infer:: traits:: util;
1214use rustc_middle:: ty:: error:: { ExpectedFound , TypeError } ;
1315use rustc_middle:: ty:: subst:: { InternalSubsts , Subst } ;
1416use rustc_middle:: ty:: util:: ExplicitSelf ;
15- use rustc_middle:: ty:: { self , DefIdTree } ;
17+ use rustc_middle:: ty:: {
18+ self , DefIdTree , Ty , TypeFoldable , TypeFolder , TypeSuperFoldable , TypeVisitable ,
19+ } ;
1620use rustc_middle:: ty:: { GenericParamDefKind , ToPredicate , TyCtxt } ;
1721use rustc_span:: Span ;
1822use rustc_trait_selection:: traits:: error_reporting:: InferCtxtExt ;
@@ -64,10 +68,7 @@ pub(crate) fn compare_impl_method<'tcx>(
6468 return ;
6569 }
6670
67- if let Err ( _) = compare_predicate_entailment ( tcx, impl_m, impl_m_span, trait_m, impl_trait_ref)
68- {
69- return ;
70- }
71+ tcx. ensure ( ) . compare_predicates_and_trait_impl_trait_tys ( impl_m. def_id ) ;
7172}
7273
7374/// This function is best explained by example. Consider a trait:
@@ -136,13 +137,15 @@ pub(crate) fn compare_impl_method<'tcx>(
136137///
137138/// Finally we register each of these predicates as an obligation and check that
138139/// they hold.
139- fn compare_predicate_entailment < ' tcx > (
140+ pub ( super ) fn compare_predicates_and_trait_impl_trait_tys < ' tcx > (
140141 tcx : TyCtxt < ' tcx > ,
141- impl_m : & ty:: AssocItem ,
142- impl_m_span : Span ,
143- trait_m : & ty:: AssocItem ,
144- impl_trait_ref : ty:: TraitRef < ' tcx > ,
145- ) -> Result < ( ) , ErrorGuaranteed > {
142+ def_id : DefId ,
143+ ) -> Result < & ' tcx FxHashMap < DefId , Ty < ' tcx > > , ErrorGuaranteed > {
144+ let impl_m = tcx. opt_associated_item ( def_id) . unwrap ( ) ;
145+ let impl_m_span = tcx. def_span ( def_id) ;
146+ let trait_m = tcx. opt_associated_item ( impl_m. trait_item_def_id . unwrap ( ) ) . unwrap ( ) ;
147+ let impl_trait_ref = tcx. impl_trait_ref ( impl_m. impl_container ( tcx) . unwrap ( ) ) . unwrap ( ) ;
148+
146149 let trait_to_impl_substs = impl_trait_ref. substs ;
147150
148151 // This node-id should be used for the `body_id` field on each
@@ -161,6 +164,7 @@ fn compare_predicate_entailment<'tcx>(
161164 kind : impl_m. kind ,
162165 } ,
163166 ) ;
167+ let return_span = tcx. hir ( ) . fn_decl_by_hir_id ( impl_m_hir_id) . unwrap ( ) . output . span ( ) ;
164168
165169 // Create mapping from impl to placeholder.
166170 let impl_to_placeholder_substs = InternalSubsts :: identity_for_item ( tcx, impl_m. def_id ) ;
@@ -266,6 +270,13 @@ fn compare_predicate_entailment<'tcx>(
266270
267271 let trait_sig = tcx. bound_fn_sig ( trait_m. def_id ) . subst ( tcx, trait_to_placeholder_substs) ;
268272 let trait_sig = tcx. liberate_late_bound_regions ( impl_m. def_id , trait_sig) ;
273+ let mut collector =
274+ ImplTraitInTraitCollector :: new ( & ocx, return_span, param_env, impl_m_hir_id) ;
275+ // FIXME(RPITIT): This should only be needed on the output type, but
276+ // RPITIT placeholders shouldn't show up anywhere except for there,
277+ // so I think this is fine.
278+ let trait_sig = trait_sig. fold_with ( & mut collector) ;
279+
269280 // Next, add all inputs and output as well-formed tys. Importantly,
270281 // we have to do this before normalization, since the normalized ty may
271282 // not contain the input parameters. See issue #87748.
@@ -391,30 +402,6 @@ fn compare_predicate_entailment<'tcx>(
391402 return Err ( diag. emit ( ) ) ;
392403 }
393404
394- // Check that an impl's fn return satisfies the bounds of the
395- // FIXME(RPITIT): Generalize this to nested impl traits
396- if let ty:: Projection ( proj) = tcx. fn_sig ( trait_m. def_id ) . skip_binder ( ) . output ( ) . kind ( )
397- && tcx. def_kind ( proj. item_def_id ) == DefKind :: ImplTraitPlaceholder
398- {
399- let return_span = tcx. hir ( ) . fn_decl_by_hir_id ( impl_m_hir_id) . unwrap ( ) . output . span ( ) ;
400-
401- for ( predicate, span) in tcx
402- . bound_explicit_item_bounds ( proj. item_def_id )
403- . transpose_iter ( )
404- . map ( |pred| pred. map_bound ( |pred| * pred) . subst ( tcx, trait_to_placeholder_substs) )
405- {
406- ocx. register_obligation ( traits:: Obligation :: new (
407- traits:: ObligationCause :: new (
408- return_span,
409- impl_m_hir_id,
410- ObligationCauseCode :: BindingObligation ( proj. item_def_id , span) ,
411- ) ,
412- param_env,
413- predicate,
414- ) ) ;
415- }
416- }
417-
418405 // Check that all obligations are satisfied by the implementation's
419406 // version.
420407 let errors = ocx. select_all_or_error ( ) ;
@@ -435,10 +422,96 @@ fn compare_predicate_entailment<'tcx>(
435422 & outlives_environment,
436423 ) ;
437424
438- Ok ( ( ) )
425+ let mut collected_tys = FxHashMap :: default ( ) ;
426+ for ( def_id, ty) in collector. types {
427+ match infcx. fully_resolve ( ty) {
428+ Ok ( ty) => {
429+ collected_tys. insert ( def_id, ty) ;
430+ }
431+ Err ( err) => {
432+ tcx. sess . delay_span_bug (
433+ return_span,
434+ format ! ( "could not fully resolve: {ty} => {err:?}" ) ,
435+ ) ;
436+ collected_tys. insert ( def_id, tcx. ty_error ( ) ) ;
437+ }
438+ }
439+ }
440+
441+ Ok ( & * tcx. arena . alloc ( collected_tys) )
439442 } )
440443}
441444
445+ struct ImplTraitInTraitCollector < ' a , ' tcx > {
446+ ocx : & ' a ObligationCtxt < ' a , ' tcx > ,
447+ types : FxHashMap < DefId , Ty < ' tcx > > ,
448+ span : Span ,
449+ param_env : ty:: ParamEnv < ' tcx > ,
450+ body_id : hir:: HirId ,
451+ }
452+
453+ impl < ' a , ' tcx > ImplTraitInTraitCollector < ' a , ' tcx > {
454+ fn new (
455+ ocx : & ' a ObligationCtxt < ' a , ' tcx > ,
456+ span : Span ,
457+ param_env : ty:: ParamEnv < ' tcx > ,
458+ body_id : hir:: HirId ,
459+ ) -> Self {
460+ ImplTraitInTraitCollector { ocx, types : FxHashMap :: default ( ) , span, param_env, body_id }
461+ }
462+ }
463+
464+ impl < ' tcx > TypeFolder < ' tcx > for ImplTraitInTraitCollector < ' _ , ' tcx > {
465+ fn tcx < ' a > ( & ' a self ) -> TyCtxt < ' tcx > {
466+ self . ocx . infcx . tcx
467+ }
468+
469+ fn fold_ty ( & mut self , ty : Ty < ' tcx > ) -> Ty < ' tcx > {
470+ if let ty:: Projection ( proj) = ty. kind ( )
471+ && self . tcx ( ) . def_kind ( proj. item_def_id ) == DefKind :: ImplTraitPlaceholder
472+ {
473+ if let Some ( ty) = self . types . get ( & proj. item_def_id ) {
474+ return * ty;
475+ }
476+ //FIXME(RPITIT): Deny nested RPITIT in substs too
477+ if proj. substs . has_escaping_bound_vars ( ) {
478+ bug ! ( "FIXME(RPITIT): error here" ) ;
479+ }
480+ // Replace with infer var
481+ let infer_ty = self . ocx . infcx . next_ty_var ( TypeVariableOrigin {
482+ span : self . span ,
483+ kind : TypeVariableOriginKind :: MiscVariable ,
484+ } ) ;
485+ self . types . insert ( proj. item_def_id , infer_ty) ;
486+ // Recurse into bounds
487+ for pred in self . tcx ( ) . bound_explicit_item_bounds ( proj. item_def_id ) . transpose_iter ( ) {
488+ let pred_span = pred. 0 . 1 ;
489+
490+ let pred = pred. map_bound ( |( pred, _) | * pred) . subst ( self . tcx ( ) , proj. substs ) ;
491+ let pred = pred. fold_with ( self ) ;
492+ let pred = self . ocx . normalize (
493+ ObligationCause :: misc ( self . span , self . body_id ) ,
494+ self . param_env ,
495+ pred,
496+ ) ;
497+
498+ self . ocx . register_obligation ( traits:: Obligation :: new (
499+ ObligationCause :: new (
500+ self . span ,
501+ self . body_id ,
502+ ObligationCauseCode :: BindingObligation ( proj. item_def_id , pred_span) ,
503+ ) ,
504+ self . param_env ,
505+ pred,
506+ ) ) ;
507+ }
508+ infer_ty
509+ } else {
510+ ty. super_fold_with ( self )
511+ }
512+ }
513+ }
514+
442515fn check_region_bounds_on_impl_item < ' tcx > (
443516 tcx : TyCtxt < ' tcx > ,
444517 impl_m : & ty:: AssocItem ,
0 commit comments