@@ -19,7 +19,7 @@ use rustls::{
1919use crate :: error:: Error ;
2020use crate :: io:: ReadBuf ;
2121use crate :: net:: tls:: util:: StdSocket ;
22- use crate :: net:: tls:: TlsConfig ;
22+ use crate :: net:: tls:: { RawTlsConfig , TlsConfig } ;
2323use crate :: net:: Socket ;
2424
2525pub struct RustlsSocket < S : Socket > {
@@ -87,100 +87,125 @@ impl<S: Socket> Socket for RustlsSocket<S> {
8787 }
8888}
8989
90- pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
91- where
92- S : Socket ,
93- {
94- #[ cfg( all(
95- feature = "_tls-rustls-aws-lc-rs" ,
96- not( feature = "_tls-rustls-ring-webpki" ) ,
97- not( feature = "_tls-rustls-ring-native-roots" )
98- ) ) ]
99- let provider = Arc :: new ( rustls:: crypto:: aws_lc_rs:: default_provider ( ) ) ;
100- #[ cfg( any(
101- feature = "_tls-rustls-ring-webpki" ,
102- feature = "_tls-rustls-ring-native-roots"
103- ) ) ]
104- let provider = Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) ;
105-
106- // Unwrapping is safe here because we use a default provider.
107- let config = ClientConfig :: builder_with_provider ( provider. clone ( ) )
108- . with_safe_default_protocol_versions ( )
109- . unwrap ( ) ;
110-
111- // authentication using user's key and its associated certificate
112- let user_auth = match ( tls_config. client_cert_path , tls_config. client_key_path ) {
113- ( Some ( cert_path) , Some ( key_path) ) => {
114- let cert_chain = certs_from_pem ( cert_path. data ( ) . await ?) ?;
115- let key_der = private_key_from_pem ( key_path. data ( ) . await ?) ?;
116- Some ( ( cert_chain, key_der) )
117- }
118- ( None , None ) => None ,
119- ( _, _) => {
120- return Err ( Error :: Configuration (
121- "user auth key and certs must be given together" . into ( ) ,
122- ) )
123- }
124- } ;
90+ impl TlsConfig < ' _ > {
91+ async fn rustls_config ( & self ) -> crate :: Result < ( rustls:: ClientConfig , & str ) , Error > {
92+ let RawTlsConfig {
93+ accept_invalid_certs,
94+ accept_invalid_hostnames,
95+ hostname,
96+ root_cert,
97+ client_cert,
98+ client_key,
99+ } = match self {
100+ TlsConfig :: RawTlsConfig ( raw) => raw,
101+ TlsConfig :: PrebuiltRustls { config, hostname } => {
102+ return Ok ( ( ( * config) . to_owned ( ) , hostname) ) ;
103+ }
104+ } ;
105+
106+ #[ cfg( all(
107+ feature = "_tls-rustls-aws-lc-rs" ,
108+ not( feature = "_tls-rustls-ring-webpki" ) ,
109+ not( feature = "_tls-rustls-ring-native-roots" )
110+ ) ) ]
111+ let provider = Arc :: new ( rustls:: crypto:: aws_lc_rs:: default_provider ( ) ) ;
112+ #[ cfg( any(
113+ feature = "_tls-rustls-ring-webpki" ,
114+ feature = "_tls-rustls-ring-native-roots"
115+ ) ) ]
116+ let provider = Arc :: new ( rustls:: crypto:: ring:: default_provider ( ) ) ;
117+
118+ // Unwrapping is safe here because we use a default provider.
119+ let config = ClientConfig :: builder_with_provider ( provider. clone ( ) )
120+ . with_safe_default_protocol_versions ( )
121+ . unwrap ( ) ;
122+
123+ // authentication using user's key and its associated certificate
124+ let user_auth = match ( client_cert, client_key) {
125+ ( Some ( cert) , Some ( key) ) => {
126+ let cert_chain = certs_from_pem ( cert. data ( ) . await ?) ?;
127+ let key_der = private_key_from_pem ( key. data ( ) . await ?) ?;
128+ Some ( ( cert_chain, key_der) )
129+ }
130+ ( None , None ) => None ,
131+ ( _, _) => {
132+ return Err ( Error :: Configuration (
133+ "user auth key and certs must be given together" . into ( ) ,
134+ ) )
135+ }
136+ } ;
125137
126- let config = if tls_config. accept_invalid_certs {
127- if let Some ( user_auth) = user_auth {
128- config
129- . dangerous ( )
130- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
131- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
132- . map_err ( Error :: tls) ?
138+ let config = if * accept_invalid_certs {
139+ if let Some ( user_auth) = user_auth {
140+ config
141+ . dangerous ( )
142+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
143+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
144+ . map_err ( Error :: tls) ?
145+ } else {
146+ config
147+ . dangerous ( )
148+ . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
149+ . with_no_client_auth ( )
150+ }
133151 } else {
134- config
135- . dangerous ( )
136- . with_custom_certificate_verifier ( Arc :: new ( DummyTlsVerifier { provider } ) )
137- . with_no_client_auth ( )
138- }
139- } else {
140- let mut cert_store = import_root_certs ( ) ;
152+ let mut cert_store = import_root_certs ( ) ;
141153
142- if let Some ( ca) = tls_config . root_cert_path {
143- let data = ca. data ( ) . await ?;
154+ if let Some ( ca) = root_cert {
155+ let data = ca. data ( ) . await ?;
144156
145- for result in CertificateDer :: pem_slice_iter ( & data) {
146- let Ok ( cert) = result else {
147- return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
148- } ;
157+ for result in CertificateDer :: pem_slice_iter ( & data) {
158+ let Ok ( cert) = result else {
159+ return Err ( Error :: Tls ( format ! ( "Invalid certificate {ca}" ) . into ( ) ) ) ;
160+ } ;
149161
150- cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
162+ cert_store. add ( cert) . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
163+ }
151164 }
152- }
153-
154- if tls_config. accept_invalid_hostnames {
155- let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
156- . build ( )
157- . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
158165
159- if let Some ( user_auth) = user_auth {
166+ if * accept_invalid_hostnames {
167+ let verifier = WebPkiServerVerifier :: builder ( Arc :: new ( cert_store) )
168+ . build ( )
169+ . map_err ( |err| Error :: Tls ( err. into ( ) ) ) ?;
170+
171+ if let Some ( user_auth) = user_auth {
172+ config
173+ . dangerous ( )
174+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
175+ verifier,
176+ } ) )
177+ . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
178+ . map_err ( Error :: tls) ?
179+ } else {
180+ config
181+ . dangerous ( )
182+ . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier {
183+ verifier,
184+ } ) )
185+ . with_no_client_auth ( )
186+ }
187+ } else if let Some ( user_auth) = user_auth {
160188 config
161- . dangerous ( )
162- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
189+ . with_root_certificates ( cert_store)
163190 . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
164191 . map_err ( Error :: tls) ?
165192 } else {
166193 config
167- . dangerous ( )
168- . with_custom_certificate_verifier ( Arc :: new ( NoHostnameTlsVerifier { verifier } ) )
194+ . with_root_certificates ( cert_store)
169195 . with_no_client_auth ( )
170196 }
171- } else if let Some ( user_auth) = user_auth {
172- config
173- . with_root_certificates ( cert_store)
174- . with_client_auth_cert ( user_auth. 0 , user_auth. 1 )
175- . map_err ( Error :: tls) ?
176- } else {
177- config
178- . with_root_certificates ( cert_store)
179- . with_no_client_auth ( )
180- }
181- } ;
197+ } ;
198+
199+ Ok ( ( config, hostname) )
200+ }
201+ }
182202
183- let host = ServerName :: try_from ( tls_config. hostname . to_owned ( ) ) . map_err ( Error :: tls) ?;
203+ pub async fn handshake < S > ( socket : S , tls_config : TlsConfig < ' _ > ) -> Result < RustlsSocket < S > , Error >
204+ where
205+ S : Socket ,
206+ {
207+ let ( config, hostname) = tls_config. rustls_config ( ) . await ?;
208+ let host = ServerName :: try_from ( hostname. to_owned ( ) ) . map_err ( Error :: tls) ?;
184209
185210 let mut socket = RustlsSocket {
186211 inner : StdSocket :: new ( socket) ,
0 commit comments