summaryrefslogtreecommitdiff
diff options
-rw-r--r--overlord/auth/auth.go44
-rw-r--r--overlord/auth/auth_test.go80
-rw-r--r--store/store.go58
-rw-r--r--store/store_test.go61
4 files changed, 196 insertions, 47 deletions
diff --git a/overlord/auth/auth.go b/overlord/auth/auth.go
index 2f1fa46749..ad99a566a1 100644
--- a/overlord/auth/auth.go
+++ b/overlord/auth/auth.go
@@ -228,15 +228,17 @@ type DeviceAssertions interface {
}
var (
+ // ErrNoSerial indicates that a device serial is not set yet.
ErrNoSerial = errors.New("no device serial yet")
)
// An AuthContext exposes authorization data and handles its updates.
type AuthContext interface {
Device() (*DeviceState, error)
- UpdateDevice(device *DeviceState) error
- UpdateUser(user *UserState) error
+ UpdateDeviceAuth(device *DeviceState, sessionMacaroon string) (actual *DeviceState, err error)
+
+ UpdateUserAuth(user *UserState, discharges []string) (actual *UserState, err error)
StoreID(fallback string) (string, error)
@@ -265,20 +267,46 @@ func (ac *authContext) Device() (*DeviceState, error) {
return Device(ac.state)
}
-// UpdateDevice updates device in state.
-func (ac *authContext) UpdateDevice(device *DeviceState) error {
+// UpdateDeviceAuth updates the device auth details in state.
+// The last update wins but other device details are left unchanged.
+// It returns the updated device state value.
+func (ac *authContext) UpdateDeviceAuth(device *DeviceState, newSessionMacaroon string) (actual *DeviceState, err error) {
ac.state.Lock()
defer ac.state.Unlock()
- return SetDevice(ac.state, device)
+ cur, err := Device(ac.state)
+ if err != nil {
+ return nil, err
+ }
+
+ // just do it, last update wins
+ cur.SessionMacaroon = newSessionMacaroon
+ if err := SetDevice(ac.state, cur); err != nil {
+ return nil, fmt.Errorf("internal error: cannot update just read device state: %v", err)
+ }
+
+ return cur, nil
}
-// UpdateUser updates user in state.
-func (ac *authContext) UpdateUser(user *UserState) error {
+// UpdateUserAuth updates the user auth details in state.
+// The last update wins but other user details are left unchanged.
+// It returns the updated user state value.
+func (ac *authContext) UpdateUserAuth(user *UserState, newDischarges []string) (actual *UserState, err error) {
ac.state.Lock()
defer ac.state.Unlock()
- return UpdateUser(ac.state, user)
+ cur, err := User(ac.state, user.ID)
+ if err != nil {
+ return nil, err
+ }
+
+ // just do it, last update wins
+ cur.StoreDischarges = newDischarges
+ if err := UpdateUser(ac.state, cur); err != nil {
+ return nil, fmt.Errorf("internal error: cannot update just read user state: %v", err)
+ }
+
+ return cur, nil
}
// StoreID returns the store id according to system state or
diff --git a/overlord/auth/auth_test.go b/overlord/auth/auth_test.go
index 6b0d5ffdd6..4a44107c56 100644
--- a/overlord/auth/auth_test.go
+++ b/overlord/auth/auth_test.go
@@ -283,16 +283,15 @@ func (as *authSuite) TestSetDevice(c *C) {
c.Check(device, DeepEquals, &auth.DeviceState{Brand: "some-brand"})
}
-func (as *authSuite) TestAuthContextUpdateUser(c *C) {
+func (as *authSuite) TestAuthContextUpdateUserAuth(c *C) {
as.state.Lock()
user, _ := auth.NewUser(as.state, "username", "macaroon", []string{"discharge"})
as.state.Unlock()
- user.Username = "different"
- user.StoreDischarges = []string{"updated-discharge"}
+ newDischarges := []string{"updated-discharge"}
authContext := auth.NewAuthContext(as.state, nil)
- err := authContext.UpdateUser(user)
+ user, err := authContext.UpdateUserAuth(user, newDischarges)
c.Check(err, IsNil)
as.state.Lock()
@@ -300,9 +299,43 @@ func (as *authSuite) TestAuthContextUpdateUser(c *C) {
as.state.Unlock()
c.Check(err, IsNil)
c.Check(userFromState, DeepEquals, user)
+ c.Check(userFromState.Discharges, DeepEquals, []string{"discharge"})
+ c.Check(user.StoreDischarges, DeepEquals, newDischarges)
+}
+
+func (as *authSuite) TestAuthContextUpdateUserAuthOtherUpdate(c *C) {
+ as.state.Lock()
+ user, _ := auth.NewUser(as.state, "username", "macaroon", []string{"discharge"})
+ otherUpdateUser := *user
+ otherUpdateUser.Macaroon = "macaroon2"
+ otherUpdateUser.StoreDischarges = []string{"other-discharges"}
+ err := auth.UpdateUser(as.state, &otherUpdateUser)
+ as.state.Unlock()
+ c.Assert(err, IsNil)
+
+ newDischarges := []string{"updated-discharge"}
+
+ authContext := auth.NewAuthContext(as.state, nil)
+ // last discharges win
+ curUser, err := authContext.UpdateUserAuth(user, newDischarges)
+ c.Assert(err, IsNil)
+
+ as.state.Lock()
+ userFromState, err := auth.User(as.state, user.ID)
+ as.state.Unlock()
+ c.Check(err, IsNil)
+ c.Check(userFromState, DeepEquals, curUser)
+ c.Check(curUser, DeepEquals, &auth.UserState{
+ ID: user.ID,
+ Username: "username",
+ Macaroon: "macaroon2",
+ Discharges: []string{"discharge"},
+ StoreMacaroon: "macaroon",
+ StoreDischarges: newDischarges,
+ })
}
-func (as *authSuite) TestAuthContextUpdateUserInvalid(c *C) {
+func (as *authSuite) TestAuthContextUpdateUserAuthInvalid(c *C) {
as.state.Lock()
_, _ = auth.NewUser(as.state, "username", "macaroon", []string{"discharge"})
as.state.Unlock()
@@ -314,7 +347,7 @@ func (as *authSuite) TestAuthContextUpdateUserInvalid(c *C) {
}
authContext := auth.NewAuthContext(as.state, nil)
- err := authContext.UpdateUser(user)
+ _, err := authContext.UpdateUserAuth(user, nil)
c.Assert(err, ErrorMatches, "invalid user")
}
@@ -340,21 +373,50 @@ func (as *authSuite) TestAuthContextDevice(c *C) {
c.Check(deviceFromState, DeepEquals, device)
}
-func (as *authSuite) TestAuthContextUpdateDevice(c *C) {
+func (as *authSuite) TestAuthContextUpdateDeviceAuth(c *C) {
as.state.Lock()
device, err := auth.Device(as.state)
as.state.Unlock()
c.Check(err, IsNil)
c.Check(device, DeepEquals, &auth.DeviceState{})
+ sessionMacaroon := "the-device-macaroon"
+
authContext := auth.NewAuthContext(as.state, nil)
- device.SessionMacaroon = "the-device-macaroon"
- err = authContext.UpdateDevice(device)
+ device, err = authContext.UpdateDeviceAuth(device, sessionMacaroon)
c.Check(err, IsNil)
deviceFromState, err := authContext.Device()
c.Check(err, IsNil)
c.Check(deviceFromState, DeepEquals, device)
+ c.Check(deviceFromState.SessionMacaroon, DeepEquals, sessionMacaroon)
+}
+
+func (as *authSuite) TestAuthContextUpdateDeviceAuthOtherUpdate(c *C) {
+ as.state.Lock()
+ device, _ := auth.Device(as.state)
+ otherUpdateDevice := *device
+ otherUpdateDevice.SessionMacaroon = "othe-session-macaroon"
+ otherUpdateDevice.KeyID = "KEYID"
+ err := auth.SetDevice(as.state, &otherUpdateDevice)
+ as.state.Unlock()
+ c.Check(err, IsNil)
+
+ sessionMacaroon := "the-device-macaroon"
+
+ authContext := auth.NewAuthContext(as.state, nil)
+ curDevice, err := authContext.UpdateDeviceAuth(device, sessionMacaroon)
+ c.Assert(err, IsNil)
+
+ as.state.Lock()
+ deviceFromState, err := auth.Device(as.state)
+ as.state.Unlock()
+ c.Check(err, IsNil)
+ c.Check(deviceFromState, DeepEquals, curDevice)
+ c.Check(curDevice, DeepEquals, &auth.DeviceState{
+ KeyID: "KEYID",
+ SessionMacaroon: sessionMacaroon,
+ })
}
func (as *authSuite) TestAuthContextStoreIDFallback(c *C) {
diff --git a/store/store.go b/store/store.go
index 2eaa7c1cb8..9a715a649a 100644
--- a/store/store.go
+++ b/store/store.go
@@ -367,52 +367,53 @@ func authenticateUser(r *http.Request, user *auth.UserState) {
r.Header.Set("Authorization", buf.String())
}
-// refreshMacaroon will request a refreshed discharge macaroon for the user
-func refreshMacaroon(user *auth.UserState) error {
+// refreshDischarges will request refreshed discharge macaroons for the user
+func refreshDischarges(user *auth.UserState) ([]string, error) {
+ newDischarges := make([]string, len(user.StoreDischarges))
for i, d := range user.StoreDischarges {
discharge, err := MacaroonDeserialize(d)
if err != nil {
- return err
+ return nil, err
}
- if discharge.Location() == UbuntuoneLocation {
- refreshedDischarge, err := RefreshDischargeMacaroon(d)
- if err != nil {
- return err
- }
- user.StoreDischarges[i] = refreshedDischarge
+ if discharge.Location() != UbuntuoneLocation {
+ newDischarges[i] = d
+ continue
}
+
+ refreshedDischarge, err := RefreshDischargeMacaroon(d)
+ if err != nil {
+ return nil, err
+ }
+ newDischarges[i] = refreshedDischarge
}
- return nil
+ return newDischarges, nil
}
// refreshUser will refresh user discharge macaroon and update state
func (s *Store) refreshUser(user *auth.UserState) error {
- err := refreshMacaroon(user)
+ newDischarges, err := refreshDischarges(user)
if err != nil {
return err
}
if s.authContext != nil {
- err = s.authContext.UpdateUser(user)
+ curUser, err := s.authContext.UpdateUserAuth(user, newDischarges)
if err != nil {
return err
}
+ // update in place
+ *user = *curUser
}
return nil
}
// refreshDeviceSession will set or refresh the device session in the state
-func (s *Store) refreshDeviceSession() error {
+func (s *Store) refreshDeviceSession(device *auth.DeviceState) error {
if s.authContext == nil {
return fmt.Errorf("internal error: no authContext")
}
- device, err := s.authContext.Device()
- if err != nil {
- return err
- }
-
nonce, err := RequestStoreDeviceNonce()
if err != nil {
return err
@@ -428,11 +429,12 @@ func (s *Store) refreshDeviceSession() error {
return err
}
- device.SessionMacaroon = session
- err = s.authContext.UpdateDevice(device)
+ curDevice, err := s.authContext.UpdateDeviceAuth(device, session)
if err != nil {
return err
}
+ // update in place
+ *device = *curDevice
return nil
}
@@ -492,7 +494,15 @@ func (s *Store) doRequest(client *http.Client, reqOptions *requestOptions, user
}
if strings.Contains(wwwAuth, "refresh_device_session=1") {
// refresh device session
- err = s.refreshDeviceSession()
+ if s.authContext == nil {
+ return nil, fmt.Errorf("internal error: no authContext")
+ }
+ device, err := s.authContext.Device()
+ if err != nil {
+ return nil, err
+ }
+
+ err = s.refreshDeviceSession(device)
if err != nil {
return nil, err
}
@@ -526,8 +536,10 @@ func (s *Store) newRequest(reqOptions *requestOptions, user *auth.UserState) (*h
if err != nil {
return nil, err
}
- if device.SessionMacaroon == "" {
- err = s.refreshDeviceSession()
+ // we don't have a session yet but have a serial, try
+ // to get a session
+ if device.SessionMacaroon == "" && device.Serial != "" {
+ err = s.refreshDeviceSession(device)
if err == auth.ErrNoSerial {
// missing serial assertion, log and continue without device authentication
logger.Debugf("cannot set device session: %v", err)
diff --git a/store/store_test.go b/store/store_test.go
index bd0931c80a..90f539b3e0 100644
--- a/store/store_test.go
+++ b/store/store_test.go
@@ -102,17 +102,23 @@ type testAuthContext struct {
}
func (ac *testAuthContext) Device() (*auth.DeviceState, error) {
- return ac.device, nil
+ freshDevice := *ac.device
+ return &freshDevice, nil
}
-func (ac *testAuthContext) UpdateDevice(d *auth.DeviceState) error {
- ac.device = d
- return nil
+func (ac *testAuthContext) UpdateDeviceAuth(d *auth.DeviceState, newSessionMacaroon string) (*auth.DeviceState, error) {
+ ac.c.Assert(d, DeepEquals, ac.device)
+ updated := *ac.device
+ updated.SessionMacaroon = newSessionMacaroon
+ *ac.device = updated
+ return &updated, nil
}
-func (ac *testAuthContext) UpdateUser(u *auth.UserState) error {
+func (ac *testAuthContext) UpdateUserAuth(u *auth.UserState, newDischarges []string) (*auth.UserState, error) {
ac.c.Assert(u, DeepEquals, ac.user)
- return nil
+ updated := *ac.user
+ updated.StoreDischarges = newDischarges
+ return &updated, nil
}
func (ac *testAuthContext) StoreID(fallback string) (string, error) {
@@ -199,6 +205,7 @@ func createTestDevice() *auth.DeviceState {
return &auth.DeviceState{
Brand: "some-brand",
SessionMacaroon: "device-macaroon",
+ Serial: "9999",
}
}
@@ -367,6 +374,40 @@ func (t *remoteRepoTestSuite) TestDoRequestSetsAuth(c *C) {
c.Check(string(responseData), Equals, "response-data")
}
+func (t *remoteRepoTestSuite) TestDoRequestAuthNoSerial(c *C) {
+ mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ c.Check(r.UserAgent(), Equals, userAgent)
+ // check user authorization is set
+ authorization := r.Header.Get("Authorization")
+ c.Check(authorization, Equals, t.expectedAuthorization(c, t.user))
+ // check device authorization was not set
+ c.Check(r.Header.Get("X-Device-Authorization"), Equals, "")
+
+ io.WriteString(w, "response-data")
+ }))
+
+ c.Assert(mockServer, NotNil)
+ defer mockServer.Close()
+
+ // no serial and no device macaroon => no device auth
+ t.device.Serial = ""
+ t.device.SessionMacaroon = ""
+ authContext := &testAuthContext{c: c, device: t.device, user: t.user}
+ repo := New(&Config{}, authContext)
+ c.Assert(repo, NotNil)
+
+ endpoint, _ := url.Parse(mockServer.URL)
+ reqOptions := &requestOptions{Method: "GET", URL: endpoint}
+
+ response, err := repo.doRequest(repo.client, reqOptions, t.user)
+ defer response.Body.Close()
+ c.Assert(err, IsNil)
+
+ responseData, err := ioutil.ReadAll(response.Body)
+ c.Assert(err, IsNil)
+ c.Check(string(responseData), Equals, "response-data")
+}
+
func (t *remoteRepoTestSuite) TestDoRequestRefreshesAuth(c *C) {
refresh, err := makeTestRefreshDischargeResponse()
c.Assert(err, IsNil)
@@ -737,6 +778,10 @@ func (t *remoteRepoTestSuite) TestUbuntuStoreRepositoryStoreIDFromAuthContext(c
func (t *remoteRepoTestSuite) TestUbuntuStoreRepositoryRevision(c *C) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if strings.HasPrefix(r.URL.Path, "/dev/api/snap-purchases") {
+ w.WriteHeader(http.StatusNotFound)
+ return
+ }
c.Check(r.URL.Path, Equals, "/details/hello-world")
c.Check(r.URL.Query(), DeepEquals, url.Values{
"channel": []string{""},
@@ -749,11 +794,12 @@ func (t *remoteRepoTestSuite) TestUbuntuStoreRepositoryRevision(c *C) {
c.Assert(mockServer, NotNil)
defer mockServer.Close()
-
+ purchasesURI, err := url.Parse(mockServer.URL + "/dev/api/snap-purchases/")
detailsURI, err := url.Parse(mockServer.URL + "/details/")
c.Assert(err, IsNil)
cfg := DefaultConfig()
cfg.DetailsURI = detailsURI
+ cfg.PurchasesURI = purchasesURI
repo := New(cfg, nil)
c.Assert(repo, NotNil)
@@ -1528,6 +1574,7 @@ func (t *remoteRepoTestSuite) TestUbuntuStoreRepositoryAssertion(c *C) {
var err error
assertionsURI, err := url.Parse(mockServer.URL + "/assertions/")
c.Assert(err, IsNil)
+
cfg := Config{
AssertionsURI: assertionsURI,
}