Skip to content

Commit f16fd88

Browse files
committed
better integration of auth into default client: move the token caching logic into oauth2 module as TokenCache since it's reusable
1 parent c6f830d commit f16fd88

File tree

1 file changed

+52
-98
lines changed

1 file changed

+52
-98
lines changed

src/default_client.rs

Lines changed: 52 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
1414
use crate::Error;
1515
use crate::client_trait::*;
16-
use crate::oauth2::Authorization;
17-
use std::sync::{Arc, RwLock};
16+
use crate::oauth2::{Authorization, TokenCache};
17+
use std::sync::Arc;
1818

1919
const USER_AGENT: &str = concat!("Dropbox-APIv2-Rust/", env!("CARGO_PKG_VERSION"));
2020

2121
macro_rules! forward_noauth_request {
22-
($self:ident, $inner:expr, $token:expr, $path_root:expr, $team_select:expr) => {
22+
($self:ident, $inner:expr, $path_root:expr) => {
2323
fn request(
2424
&$self,
2525
endpoint: Endpoint,
@@ -32,13 +32,13 @@ macro_rules! forward_noauth_request {
3232
range_end: Option<u64>,
3333
) -> crate::Result<HttpRequestResultRaw> {
3434
$inner.request(endpoint, style, function, &params, params_type, body, range_start,
35-
range_end, $token, $path_root, $team_select)
35+
range_end, None, $path_root, None)
3636
}
3737
}
3838
}
3939

40-
macro_rules! forward_auth_request {
41-
($self:ident, $inner:expr, $path_root:expr, $team_select:expr) => {
40+
macro_rules! forward_authed_request {
41+
($self:ident, $tokens:expr, $inner:expr, $path_root:expr, $team_select:expr) => {
4242
fn request(
4343
&$self,
4444
endpoint: Endpoint,
@@ -50,8 +50,32 @@ macro_rules! forward_auth_request {
5050
range_start: Option<u64>,
5151
range_end: Option<u64>,
5252
) -> crate::Result<HttpRequestResultRaw> {
53-
$inner.request(endpoint, style, function, &params, params_type, body, range_start,
54-
range_end, $path_root, $team_select)
53+
let mut token = $tokens.get_token(TokenUpdateClient { inner: &$inner })?;
54+
55+
let mut retried = false;
56+
loop {
57+
let result = $inner.request(endpoint, style, function, &params, params_type, body,
58+
range_start, range_end, Some(&token), $path_root, $team_select);
59+
60+
if retried {
61+
break result;
62+
}
63+
64+
if let Err(crate::Error::InvalidToken(msg)) = &result {
65+
if msg == "expired_access_token" {
66+
info!("refreshing token");
67+
let old_token = token;
68+
token = $tokens.update_token(
69+
TokenUpdateClient { inner: &$inner },
70+
old_token,
71+
)?;
72+
retried = true;
73+
continue;
74+
}
75+
}
76+
77+
break result;
78+
}
5579
}
5680
}
5781
}
@@ -76,15 +100,23 @@ macro_rules! impl_set_path_root {
76100

77101
/// Default HTTP client using User authorization.
78102
pub struct UserAuthDefaultClient {
79-
inner: UreqAuthClient,
103+
inner: UreqClient,
104+
tokens: Arc<TokenCache>,
80105
path_root: Option<String>, // a serialized PathRoot enum
81106
}
82107

83108
impl UserAuthDefaultClient {
84-
/// Create a new client using the given OAuth2 token.
109+
/// Create a new client using the given OAuth2 authorization.
85110
pub fn new(auth: Authorization) -> Self {
111+
Self::from_token_cache(Arc::new(TokenCache::new(auth)))
112+
}
113+
114+
/// Create a new client from a [`TokenCache`], which lets you share the same tokens between
115+
/// multiple clients.
116+
pub fn from_token_cache(tokens: Arc<TokenCache>) -> Self {
86117
Self {
87-
inner: UreqAuthClient::new(auth),
118+
inner: UreqClient::default(),
119+
tokens,
88120
path_root: None,
89121
}
90122
}
@@ -93,23 +125,25 @@ impl UserAuthDefaultClient {
93125
}
94126

95127
impl HttpClient for UserAuthDefaultClient {
96-
forward_auth_request! { self, self.inner, self.path_root.as_deref(), None }
128+
forward_authed_request! { self, self.tokens, self.inner, self.path_root.as_deref(), None }
97129
}
98130

99131
impl UserAuthClient for UserAuthDefaultClient {}
100132

101133
/// Default HTTP client using Team authorization.
102134
pub struct TeamAuthDefaultClient {
103-
inner: UreqAuthClient,
135+
inner: UreqClient,
136+
tokens: Arc<TokenCache>,
104137
path_root: Option<String>, // a serialized PathRoot enum
105138
team_select: Option<TeamSelect>,
106139
}
107140

108141
impl TeamAuthDefaultClient {
109142
/// Create a new client using the given OAuth2 token, with no user/admin context selected.
110-
pub fn new(auth: Authorization) -> Self {
143+
pub fn new(tokens: impl Into<Arc<TokenCache>>) -> Self {
111144
Self {
112-
inner: UreqAuthClient::new(auth),
145+
inner: UreqClient::default(),
146+
tokens: tokens.into(),
113147
path_root: None,
114148
team_select: None,
115149
}
@@ -124,7 +158,7 @@ impl TeamAuthDefaultClient {
124158
}
125159

126160
impl HttpClient for TeamAuthDefaultClient {
127-
forward_auth_request! { self, self.inner, self.path_root.as_deref(), self.team_select.as_ref() }
161+
forward_authed_request! { self, self.tokens, self.inner, self.path_root.as_deref(), self.team_select.as_ref() }
128162
}
129163

130164
impl TeamAuthClient for TeamAuthDefaultClient {}
@@ -141,7 +175,7 @@ impl NoauthDefaultClient {
141175
}
142176

143177
impl HttpClient for NoauthDefaultClient {
144-
forward_noauth_request! { self, self.inner, None, self.path_root.as_deref(), None }
178+
forward_noauth_request! { self, self.inner, self.path_root.as_deref() }
145179
}
146180

147181
impl NoauthClient for NoauthDefaultClient {}
@@ -153,91 +187,11 @@ struct TokenUpdateClient<'a> {
153187
}
154188

155189
impl<'a> HttpClient for TokenUpdateClient<'a> {
156-
forward_noauth_request! { self, self.inner, None, None, None }
190+
forward_noauth_request! { self, self.inner, None }
157191
}
158192

159193
impl<'a> NoauthClient for TokenUpdateClient<'a> {}
160194

161-
struct UreqAuthClient {
162-
inner: UreqClient,
163-
auth: RwLock<(Authorization, Arc<String>)>,
164-
}
165-
166-
impl UreqAuthClient {
167-
fn new(auth: Authorization) -> Self {
168-
Self {
169-
inner: UreqClient::default(),
170-
auth: RwLock::new((auth, Arc::new(String::new()))), // obtain a token on first request
171-
}
172-
}
173-
}
174-
175-
impl UreqAuthClient {
176-
#[allow(clippy::too_many_arguments)]
177-
fn request(
178-
&self,
179-
endpoint: Endpoint,
180-
style: Style,
181-
function: &str,
182-
params: &str,
183-
params_type: ParamsType,
184-
body: Option<&[u8]>,
185-
range_start: Option<u64>,
186-
range_end: Option<u64>,
187-
path_root: Option<&str>,
188-
team_select: Option<&TeamSelect>,
189-
) -> crate::Result<HttpRequestResultRaw> {
190-
let mut token: Arc<String> = {
191-
let read = self.auth.read().unwrap();
192-
if read.1.is_empty() {
193-
drop(read);
194-
let mut write = self.auth.write().unwrap();
195-
if write.1.is_empty() {
196-
// Check again; it's possible someone else updated it while
197-
// we were unlocked.
198-
info!("Requesting initial OAuth2 token");
199-
let client = TokenUpdateClient { inner: &self.inner };
200-
write.1 = Arc::new(write.0.obtain_access_token(client)?);
201-
}
202-
Arc::clone(&write.1)
203-
} else {
204-
Arc::clone(&read.1)
205-
}
206-
};
207-
208-
let mut retried = false;
209-
loop {
210-
let result = self.inner.request(
211-
endpoint, style, function, params, params_type, body, range_start, range_end,
212-
Some(&token), path_root, team_select);
213-
214-
if retried {
215-
break result;
216-
}
217-
218-
if let Err(crate::Error::InvalidToken(msg)) = &result {
219-
if msg == "expired_access_token" {
220-
let mut write = self.auth.write().unwrap();
221-
// Check if the token changed while we were unlocked; only update it if it
222-
// didn't.
223-
if write.1 == token {
224-
info!("Refreshing OAuth2 token");
225-
let client = TokenUpdateClient { inner: &self.inner };
226-
token = Arc::new(write.0.obtain_access_token(client)?);
227-
write.1 = Arc::clone(&token);
228-
} else {
229-
token = Arc::clone(&write.1);
230-
}
231-
retried = true;
232-
continue;
233-
}
234-
}
235-
236-
break result;
237-
}
238-
}
239-
}
240-
241195
#[derive(Debug, Default)]
242196
struct UreqClient {}
243197

0 commit comments

Comments
 (0)