diff options
| -rw-r--r-- | overlord/auth/auth.go | 44 | ||||
| -rw-r--r-- | overlord/auth/auth_test.go | 80 | ||||
| -rw-r--r-- | store/store.go | 58 | ||||
| -rw-r--r-- | store/store_test.go | 61 |
4 files changed, 196 insertions, 47 deletions
diff --git a/overlord/auth/auth.go b/overlord/auth/auth.go index 2f1fa46749..ad99a566a1 100644 --- a/overlord/auth/auth.go +++ b/overlord/auth/auth.go @@ -228,15 +228,17 @@ type DeviceAssertions interface { } var ( + // ErrNoSerial indicates that a device serial is not set yet. ErrNoSerial = errors.New("no device serial yet") ) // An AuthContext exposes authorization data and handles its updates. type AuthContext interface { Device() (*DeviceState, error) - UpdateDevice(device *DeviceState) error - UpdateUser(user *UserState) error + UpdateDeviceAuth(device *DeviceState, sessionMacaroon string) (actual *DeviceState, err error) + + UpdateUserAuth(user *UserState, discharges []string) (actual *UserState, err error) StoreID(fallback string) (string, error) @@ -265,20 +267,46 @@ func (ac *authContext) Device() (*DeviceState, error) { return Device(ac.state) } -// UpdateDevice updates device in state. -func (ac *authContext) UpdateDevice(device *DeviceState) error { +// UpdateDeviceAuth updates the device auth details in state. +// The last update wins but other device details are left unchanged. +// It returns the updated device state value. +func (ac *authContext) UpdateDeviceAuth(device *DeviceState, newSessionMacaroon string) (actual *DeviceState, err error) { ac.state.Lock() defer ac.state.Unlock() - return SetDevice(ac.state, device) + cur, err := Device(ac.state) + if err != nil { + return nil, err + } + + // just do it, last update wins + cur.SessionMacaroon = newSessionMacaroon + if err := SetDevice(ac.state, cur); err != nil { + return nil, fmt.Errorf("internal error: cannot update just read device state: %v", err) + } + + return cur, nil } -// UpdateUser updates user in state. -func (ac *authContext) UpdateUser(user *UserState) error { +// UpdateUserAuth updates the user auth details in state. +// The last update wins but other user details are left unchanged. +// It returns the updated user state value. +func (ac *authContext) UpdateUserAuth(user *UserState, newDischarges []string) (actual *UserState, err error) { ac.state.Lock() defer ac.state.Unlock() - return UpdateUser(ac.state, user) + cur, err := User(ac.state, user.ID) + if err != nil { + return nil, err + } + + // just do it, last update wins + cur.StoreDischarges = newDischarges + if err := UpdateUser(ac.state, cur); err != nil { + return nil, fmt.Errorf("internal error: cannot update just read user state: %v", err) + } + + return cur, nil } // StoreID returns the store id according to system state or diff --git a/overlord/auth/auth_test.go b/overlord/auth/auth_test.go index 6b0d5ffdd6..4a44107c56 100644 --- a/overlord/auth/auth_test.go +++ b/overlord/auth/auth_test.go @@ -283,16 +283,15 @@ func (as *authSuite) TestSetDevice(c *C) { c.Check(device, DeepEquals, &auth.DeviceState{Brand: "some-brand"}) } -func (as *authSuite) TestAuthContextUpdateUser(c *C) { +func (as *authSuite) TestAuthContextUpdateUserAuth(c *C) { as.state.Lock() user, _ := auth.NewUser(as.state, "username", "macaroon", []string{"discharge"}) as.state.Unlock() - user.Username = "different" - user.StoreDischarges = []string{"updated-discharge"} + newDischarges := []string{"updated-discharge"} authContext := auth.NewAuthContext(as.state, nil) - err := authContext.UpdateUser(user) + user, err := authContext.UpdateUserAuth(user, newDischarges) c.Check(err, IsNil) as.state.Lock() @@ -300,9 +299,43 @@ func (as *authSuite) TestAuthContextUpdateUser(c *C) { as.state.Unlock() c.Check(err, IsNil) c.Check(userFromState, DeepEquals, user) + c.Check(userFromState.Discharges, DeepEquals, []string{"discharge"}) + c.Check(user.StoreDischarges, DeepEquals, newDischarges) +} + +func (as *authSuite) TestAuthContextUpdateUserAuthOtherUpdate(c *C) { + as.state.Lock() + user, _ := auth.NewUser(as.state, "username", "macaroon", []string{"discharge"}) + otherUpdateUser := *user + otherUpdateUser.Macaroon = "macaroon2" + otherUpdateUser.StoreDischarges = []string{"other-discharges"} + err := auth.UpdateUser(as.state, &otherUpdateUser) + as.state.Unlock() + c.Assert(err, IsNil) + + newDischarges := []string{"updated-discharge"} + + authContext := auth.NewAuthContext(as.state, nil) + // last discharges win + curUser, err := authContext.UpdateUserAuth(user, newDischarges) + c.Assert(err, IsNil) + + as.state.Lock() + userFromState, err := auth.User(as.state, user.ID) + as.state.Unlock() + c.Check(err, IsNil) + c.Check(userFromState, DeepEquals, curUser) + c.Check(curUser, DeepEquals, &auth.UserState{ + ID: user.ID, + Username: "username", + Macaroon: "macaroon2", + Discharges: []string{"discharge"}, + StoreMacaroon: "macaroon", + StoreDischarges: newDischarges, + }) } -func (as *authSuite) TestAuthContextUpdateUserInvalid(c *C) { +func (as *authSuite) TestAuthContextUpdateUserAuthInvalid(c *C) { as.state.Lock() _, _ = auth.NewUser(as.state, "username", "macaroon", []string{"discharge"}) as.state.Unlock() @@ -314,7 +347,7 @@ func (as *authSuite) TestAuthContextUpdateUserInvalid(c *C) { } authContext := auth.NewAuthContext(as.state, nil) - err := authContext.UpdateUser(user) + _, err := authContext.UpdateUserAuth(user, nil) c.Assert(err, ErrorMatches, "invalid user") } @@ -340,21 +373,50 @@ func (as *authSuite) TestAuthContextDevice(c *C) { c.Check(deviceFromState, DeepEquals, device) } -func (as *authSuite) TestAuthContextUpdateDevice(c *C) { +func (as *authSuite) TestAuthContextUpdateDeviceAuth(c *C) { as.state.Lock() device, err := auth.Device(as.state) as.state.Unlock() c.Check(err, IsNil) c.Check(device, DeepEquals, &auth.DeviceState{}) + sessionMacaroon := "the-device-macaroon" + authContext := auth.NewAuthContext(as.state, nil) - device.SessionMacaroon = "the-device-macaroon" - err = authContext.UpdateDevice(device) + device, err = authContext.UpdateDeviceAuth(device, sessionMacaroon) c.Check(err, IsNil) deviceFromState, err := authContext.Device() c.Check(err, IsNil) c.Check(deviceFromState, DeepEquals, device) + c.Check(deviceFromState.SessionMacaroon, DeepEquals, sessionMacaroon) +} + +func (as *authSuite) TestAuthContextUpdateDeviceAuthOtherUpdate(c *C) { + as.state.Lock() + device, _ := auth.Device(as.state) + otherUpdateDevice := *device + otherUpdateDevice.SessionMacaroon = "othe-session-macaroon" + otherUpdateDevice.KeyID = "KEYID" + err := auth.SetDevice(as.state, &otherUpdateDevice) + as.state.Unlock() + c.Check(err, IsNil) + + sessionMacaroon := "the-device-macaroon" + + authContext := auth.NewAuthContext(as.state, nil) + curDevice, err := authContext.UpdateDeviceAuth(device, sessionMacaroon) + c.Assert(err, IsNil) + + as.state.Lock() + deviceFromState, err := auth.Device(as.state) + as.state.Unlock() + c.Check(err, IsNil) + c.Check(deviceFromState, DeepEquals, curDevice) + c.Check(curDevice, DeepEquals, &auth.DeviceState{ + KeyID: "KEYID", + SessionMacaroon: sessionMacaroon, + }) } func (as *authSuite) TestAuthContextStoreIDFallback(c *C) { diff --git a/store/store.go b/store/store.go index 2eaa7c1cb8..9a715a649a 100644 --- a/store/store.go +++ b/store/store.go @@ -367,52 +367,53 @@ func authenticateUser(r *http.Request, user *auth.UserState) { r.Header.Set("Authorization", buf.String()) } -// refreshMacaroon will request a refreshed discharge macaroon for the user -func refreshMacaroon(user *auth.UserState) error { +// refreshDischarges will request refreshed discharge macaroons for the user +func refreshDischarges(user *auth.UserState) ([]string, error) { + newDischarges := make([]string, len(user.StoreDischarges)) for i, d := range user.StoreDischarges { discharge, err := MacaroonDeserialize(d) if err != nil { - return err + return nil, err } - if discharge.Location() == UbuntuoneLocation { - refreshedDischarge, err := RefreshDischargeMacaroon(d) - if err != nil { - return err - } - user.StoreDischarges[i] = refreshedDischarge + if discharge.Location() != UbuntuoneLocation { + newDischarges[i] = d + continue } + + refreshedDischarge, err := RefreshDischargeMacaroon(d) + if err != nil { + return nil, err + } + newDischarges[i] = refreshedDischarge } - return nil + return newDischarges, nil } // refreshUser will refresh user discharge macaroon and update state func (s *Store) refreshUser(user *auth.UserState) error { - err := refreshMacaroon(user) + newDischarges, err := refreshDischarges(user) if err != nil { return err } if s.authContext != nil { - err = s.authContext.UpdateUser(user) + curUser, err := s.authContext.UpdateUserAuth(user, newDischarges) if err != nil { return err } + // update in place + *user = *curUser } return nil } // refreshDeviceSession will set or refresh the device session in the state -func (s *Store) refreshDeviceSession() error { +func (s *Store) refreshDeviceSession(device *auth.DeviceState) error { if s.authContext == nil { return fmt.Errorf("internal error: no authContext") } - device, err := s.authContext.Device() - if err != nil { - return err - } - nonce, err := RequestStoreDeviceNonce() if err != nil { return err @@ -428,11 +429,12 @@ func (s *Store) refreshDeviceSession() error { return err } - device.SessionMacaroon = session - err = s.authContext.UpdateDevice(device) + curDevice, err := s.authContext.UpdateDeviceAuth(device, session) if err != nil { return err } + // update in place + *device = *curDevice return nil } @@ -492,7 +494,15 @@ func (s *Store) doRequest(client *http.Client, reqOptions *requestOptions, user } if strings.Contains(wwwAuth, "refresh_device_session=1") { // refresh device session - err = s.refreshDeviceSession() + if s.authContext == nil { + return nil, fmt.Errorf("internal error: no authContext") + } + device, err := s.authContext.Device() + if err != nil { + return nil, err + } + + err = s.refreshDeviceSession(device) if err != nil { return nil, err } @@ -526,8 +536,10 @@ func (s *Store) newRequest(reqOptions *requestOptions, user *auth.UserState) (*h if err != nil { return nil, err } - if device.SessionMacaroon == "" { - err = s.refreshDeviceSession() + // we don't have a session yet but have a serial, try + // to get a session + if device.SessionMacaroon == "" && device.Serial != "" { + err = s.refreshDeviceSession(device) if err == auth.ErrNoSerial { // missing serial assertion, log and continue without device authentication logger.Debugf("cannot set device session: %v", err) diff --git a/store/store_test.go b/store/store_test.go index bd0931c80a..90f539b3e0 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -102,17 +102,23 @@ type testAuthContext struct { } func (ac *testAuthContext) Device() (*auth.DeviceState, error) { - return ac.device, nil + freshDevice := *ac.device + return &freshDevice, nil } -func (ac *testAuthContext) UpdateDevice(d *auth.DeviceState) error { - ac.device = d - return nil +func (ac *testAuthContext) UpdateDeviceAuth(d *auth.DeviceState, newSessionMacaroon string) (*auth.DeviceState, error) { + ac.c.Assert(d, DeepEquals, ac.device) + updated := *ac.device + updated.SessionMacaroon = newSessionMacaroon + *ac.device = updated + return &updated, nil } -func (ac *testAuthContext) UpdateUser(u *auth.UserState) error { +func (ac *testAuthContext) UpdateUserAuth(u *auth.UserState, newDischarges []string) (*auth.UserState, error) { ac.c.Assert(u, DeepEquals, ac.user) - return nil + updated := *ac.user + updated.StoreDischarges = newDischarges + return &updated, nil } func (ac *testAuthContext) StoreID(fallback string) (string, error) { @@ -199,6 +205,7 @@ func createTestDevice() *auth.DeviceState { return &auth.DeviceState{ Brand: "some-brand", SessionMacaroon: "device-macaroon", + Serial: "9999", } } @@ -367,6 +374,40 @@ func (t *remoteRepoTestSuite) TestDoRequestSetsAuth(c *C) { c.Check(string(responseData), Equals, "response-data") } +func (t *remoteRepoTestSuite) TestDoRequestAuthNoSerial(c *C) { + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Check(r.UserAgent(), Equals, userAgent) + // check user authorization is set + authorization := r.Header.Get("Authorization") + c.Check(authorization, Equals, t.expectedAuthorization(c, t.user)) + // check device authorization was not set + c.Check(r.Header.Get("X-Device-Authorization"), Equals, "") + + io.WriteString(w, "response-data") + })) + + c.Assert(mockServer, NotNil) + defer mockServer.Close() + + // no serial and no device macaroon => no device auth + t.device.Serial = "" + t.device.SessionMacaroon = "" + authContext := &testAuthContext{c: c, device: t.device, user: t.user} + repo := New(&Config{}, authContext) + c.Assert(repo, NotNil) + + endpoint, _ := url.Parse(mockServer.URL) + reqOptions := &requestOptions{Method: "GET", URL: endpoint} + + response, err := repo.doRequest(repo.client, reqOptions, t.user) + defer response.Body.Close() + c.Assert(err, IsNil) + + responseData, err := ioutil.ReadAll(response.Body) + c.Assert(err, IsNil) + c.Check(string(responseData), Equals, "response-data") +} + func (t *remoteRepoTestSuite) TestDoRequestRefreshesAuth(c *C) { refresh, err := makeTestRefreshDischargeResponse() c.Assert(err, IsNil) @@ -737,6 +778,10 @@ func (t *remoteRepoTestSuite) TestUbuntuStoreRepositoryStoreIDFromAuthContext(c func (t *remoteRepoTestSuite) TestUbuntuStoreRepositoryRevision(c *C) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/dev/api/snap-purchases") { + w.WriteHeader(http.StatusNotFound) + return + } c.Check(r.URL.Path, Equals, "/details/hello-world") c.Check(r.URL.Query(), DeepEquals, url.Values{ "channel": []string{""}, @@ -749,11 +794,12 @@ func (t *remoteRepoTestSuite) TestUbuntuStoreRepositoryRevision(c *C) { c.Assert(mockServer, NotNil) defer mockServer.Close() - + purchasesURI, err := url.Parse(mockServer.URL + "/dev/api/snap-purchases/") detailsURI, err := url.Parse(mockServer.URL + "/details/") c.Assert(err, IsNil) cfg := DefaultConfig() cfg.DetailsURI = detailsURI + cfg.PurchasesURI = purchasesURI repo := New(cfg, nil) c.Assert(repo, NotNil) @@ -1528,6 +1574,7 @@ func (t *remoteRepoTestSuite) TestUbuntuStoreRepositoryAssertion(c *C) { var err error assertionsURI, err := url.Parse(mockServer.URL + "/assertions/") c.Assert(err, IsNil) + cfg := Config{ AssertionsURI: assertionsURI, } |
