Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions overlord/devicestate/devicestate_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,13 @@ func (s *deviceMgrSerialSuite) mockServer(c *C, reqID string, bhv *devicestatete
bhv.SignSerial = s.signSerial
bhv.ExpectedCapabilities = "serial-stream"

return devicestatetest.MockDeviceService(c, bhv)
mockServer, extraCerts := devicestatetest.MockDeviceService(c, bhv)
fname := filepath.Join(dirs.SnapdStoreSSLCertsDir, "test-server-certs.pem")
err := os.MkdirAll(filepath.Dir(fname), 0755)
c.Assert(err, IsNil)
err = ioutil.WriteFile(fname, extraCerts, 0644)
c.Assert(err, IsNil)
return mockServer
}

func (s *deviceMgrSerialSuite) findBecomeOperationalChange(skipIDs ...string) *state.Change {
Expand Down Expand Up @@ -954,7 +960,7 @@ func (s *deviceMgrSerialSuite) TestDoRequestSerialCertExpired(c *C) {
}

c.Check(chg.Status(), Equals, state.ErrorStatus)
c.Assert(chg.Err(), ErrorMatches, `(?ms).*cannot retrieve request-id for making a request for a serial: Post \"?http://.*/request-id\"?: x509: certificate has expired or is not yet valid.*`)
c.Assert(chg.Err(), ErrorMatches, `(?ms).*cannot retrieve request-id for making a request for a serial: Post \"?https://.*/request-id\"?: x509: certificate has expired or is not yet valid.*`)

var nTentatives int
err := t.Get("pre-poll-tentatives", &nTentatives)
Expand Down Expand Up @@ -1165,7 +1171,7 @@ func (s *deviceMgrSerialSuite) testFullDeviceRegistrationHappyWithHookAndProxy(c
reqID = "REQID-42"
storeVersion = "6"
bhv.PostPreflight = func(c *C, bhv *devicestatetest.DeviceServiceBehavior, w http.ResponseWriter, r *http.Request) {
c.Check(r.Header.Get("X-Snap-Device-Service-URL"), Matches, "http://[^/]*/bad/svc/")
c.Check(r.Header.Get("X-Snap-Device-Service-URL"), Matches, "https://[^/]*/bad/svc/")
c.Check(r.Header.Get("X-Extra-Header"), Equals, "extra")
}
svcPath = "/bad/svc/"
Expand Down Expand Up @@ -1690,14 +1696,13 @@ func (s *deviceMgrSerialSuite) TestNewEnoughProxyParse(c *C) {
s.state.Lock()
defer s.state.Unlock()

log, restore := logger.MockLogger()
defer restore()
os.Setenv("SNAPD_DEBUG", "1")
defer os.Unsetenv("SNAPD_DEBUG")

badURL := &url.URL{Opaque: "%a"} // url.Parse(badURL.String()) needs to fail, which isn't easy :-)
c.Check(devicestate.NewEnoughProxy(s.state, badURL, http.DefaultClient), Equals, false)
c.Check(log.String(), Matches, "(?m).* DEBUG: Cannot check whether proxy store supports a custom serial vault: parse .*")
newEnoughProxy, err := devicestate.NewEnoughProxy(s.state, badURL, http.DefaultClient)
c.Check(err, ErrorMatches, "cannot check whether proxy store supports a custom serial vault: parse .*")
c.Check(newEnoughProxy, Equals, false)
}

func (s *deviceMgrSerialSuite) TestNewEnoughProxy(c *C) {
Expand Down Expand Up @@ -1746,18 +1751,19 @@ func (s *deviceMgrSerialSuite) TestNewEnoughProxy(c *C) {
u, err := url.Parse(server.URL)
c.Assert(err, IsNil)
for _, expected := range expecteds {
log.Reset()
c.Check(devicestate.NewEnoughProxy(s.state, u, http.DefaultClient), Equals, false)
if len(expected) > 0 {
expected = "(?m).* DEBUG: Cannot check whether proxy store supports a custom serial vault: " + expected
newEnoughProxy, err := devicestate.NewEnoughProxy(s.state, u, http.DefaultClient)
if expected != "" {
expected = "cannot check whether proxy store supports a custom serial vault: " + expected
c.Check(err, ErrorMatches, expected)
}
c.Check(log.String(), Matches, expected)
c.Check(newEnoughProxy, Equals, false)
}
c.Check(n, Equals, len(expecteds))

// and success at last
log.Reset()
c.Check(devicestate.NewEnoughProxy(s.state, u, http.DefaultClient), Equals, true)
newEnoughProxy, err := devicestate.NewEnoughProxy(s.state, u, http.DefaultClient)
c.Check(err, IsNil)
c.Check(newEnoughProxy, Equals, true)
c.Check(log.String(), Equals, "")
c.Check(n, Equals, len(expecteds)+1)
}
Expand Down
16 changes: 14 additions & 2 deletions overlord/devicestate/devicestatetest/devicesvc.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package devicestatetest

import (
"bytes"
"encoding/pem"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -60,7 +61,7 @@ const (
serialURLPath = "/api/v1/snaps/auth/devices"
)

func MockDeviceService(c *C, bhv *DeviceServiceBehavior) *httptest.Server {
func MockDeviceService(c *C, bhv *DeviceServiceBehavior) (mockServer *httptest.Server, extraPemEncodedCerts []byte) {
expectedUserAgent := snapdenv.UserAgent()

// default URL paths
Expand All @@ -75,7 +76,8 @@ func MockDeviceService(c *C, bhv *DeviceServiceBehavior) *httptest.Server {

var mu sync.Mutex
count := 0
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// TODO: extract handler func
server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// check.Assert here will produce harder to understand failure
// modes

Expand Down Expand Up @@ -252,4 +254,14 @@ func MockDeviceService(c *C, bhv *DeviceServiceBehavior) *httptest.Server {
}
}
}))
pemEncodedCerts := bytes.NewBuffer(nil)
for _, c1 := range server.TLS.Certificates {
block := &pem.Block{
Type: "CERTIFICATE",
Bytes: c1.Certificate[0],
}
err := pem.Encode(pemEncodedCerts, block)
c.Assert(err, IsNil)
}
return server, pemEncodedCerts.Bytes()
}
8 changes: 7 additions & 1 deletion overlord/devicestate/firstboot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2182,7 +2182,13 @@ func (s *firstBoot16Suite) mockServer(c *C, reqID string) *httptest.Server {
ExpectedCapabilities: "serial-stream",
}

return devicestatetest.MockDeviceService(c, bhv)
mockServer, extraCerts := devicestatetest.MockDeviceService(c, bhv)
fname := filepath.Join(dirs.SnapdStoreSSLCertsDir, "test-server-certs.pem")
err := os.MkdirAll(filepath.Dir(fname), 0755)
c.Assert(err, IsNil)
err = ioutil.WriteFile(fname, extraCerts, 0644)
c.Assert(err, IsNil)
return mockServer
}

func (s *firstBoot16Suite) signSerial(c *C, bhv *devicestatetest.DeviceServiceBehavior, headers map[string]interface{}, body []byte) (serial asserts.Assertion, ancillary []asserts.Assertion, err error) {
Expand Down
35 changes: 20 additions & 15 deletions overlord/devicestate/handlers_serial.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"gopkg.in/tomb.v2"

"github.com/snapcore/snapd/asserts"
"github.com/snapcore/snapd/dirs"
"github.com/snapcore/snapd/httputil"
"github.com/snapcore/snapd/logger"
"github.com/snapcore/snapd/overlord/assertstate"
Expand Down Expand Up @@ -122,37 +123,32 @@ func (m *DeviceManager) doGenerateDeviceKey(t *state.Task, _ *tomb.Tomb) error {
return nil
}

func newEnoughProxy(st *state.State, proxyURL *url.URL, client *http.Client) bool {
func newEnoughProxy(st *state.State, proxyURL *url.URL, client *http.Client) (bool, error) {
st.Unlock()
defer st.Lock()

const prefix = "Cannot check whether proxy store supports a custom serial vault"
const prefix = "cannot check whether proxy store supports a custom serial vault"

req, err := http.NewRequest("HEAD", proxyURL.String(), nil)
if err != nil {
// can't really happen unless proxyURL is somehow broken
logger.Debugf(prefix+": %v", err)
return false
return false, fmt.Errorf(prefix+": %v", err)
}
req.Header.Set("User-Agent", snapdenv.UserAgent())
resp, err := client.Do(req)
if err != nil {
// some sort of network or protocol error
logger.Debugf(prefix+": %v", err)
return false
return false, fmt.Errorf(prefix+": %v", err)
}
resp.Body.Close()
if resp.StatusCode != 200 {
logger.Debugf(prefix+": Head request returned %s.", resp.Status)
return false
return false, fmt.Errorf(prefix+": Head request returned %s.", resp.Status)
}
verstr := resp.Header.Get("Snap-Store-Version")
ver, err := strconv.Atoi(verstr)
if err != nil {
logger.Debugf(prefix+": Bogus Snap-Store-Version header %q.", verstr)
return false
return false, fmt.Errorf(prefix+": Bogus Snap-Store-Version header %q.", verstr)
}
return ver >= 6
return ver >= 6, nil
}

func (cfg *serialRequestConfig) setURLs(proxyURL, svcURL *url.URL) {
Expand Down Expand Up @@ -523,6 +519,9 @@ func getSerial(t *state.Task, regCtx registrationContext, privKey asserts.Privat
MayLogBody: true,
Proxy: proxyConf.Conf,
ProxyConnectHeader: http.Header{"User-Agent": []string{snapdenv.UserAgent()}},
ExtraSSLCerts: &httputil.ExtraSSLCertsFromDir{
Dir: dirs.SnapdStoreSSLCertsDir,
},
})

cfg, err := getSerialRequestConfig(t, regCtx, client)
Expand Down Expand Up @@ -655,9 +654,15 @@ func getSerialRequestConfig(t *state.Task, regCtx registrationContext, client *h
}
}

if proxyURL != nil && svcURL != nil && !newEnoughProxy(st, proxyURL, client) {
logger.Noticef("Proxy store does not support custom serial vault; ignoring the proxy")
proxyURL = nil
if proxyURL != nil && svcURL != nil {
newEnough, err := newEnoughProxy(st, proxyURL, client)
if err != nil {
return nil, err
}
if !newEnough {
logger.Noticef("Proxy store does not support custom serial vault; ignoring the proxy")
proxyURL = nil
}
}

cfg.setURLs(proxyURL, svcURL)
Expand Down
14 changes: 12 additions & 2 deletions overlord/managers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6599,8 +6599,13 @@ func (s *mgrsSuiteCore) TestHappyDeviceRegistrationWithPrepareDeviceHook(c *C) {
SignSerial: signSerial,
}

mockServer := devicestatetest.MockDeviceService(c, bhv)
mockServer, extraCerts := devicestatetest.MockDeviceService(c, bhv)
defer mockServer.Close()
fname := filepath.Join(dirs.SnapdStoreSSLCertsDir, "test-server-certs.pem")
err = os.MkdirAll(filepath.Dir(fname), 0755)
c.Assert(err, IsNil)
err = ioutil.WriteFile(fname, extraCerts, 0644)
c.Assert(err, IsNil)

pDBhv := &devicestatetest.PrepareDeviceBehavior{
DeviceSvcURL: mockServer.URL + "/svc/",
Expand Down Expand Up @@ -6744,8 +6749,13 @@ func (s *mgrsSuiteCore) TestRemodelReregistration(c *C) {
SignSerial: signSerial,
}

mockDeviceService := devicestatetest.MockDeviceService(c, bhv)
mockDeviceService, extraCerts := devicestatetest.MockDeviceService(c, bhv)
defer mockDeviceService.Close()
fname := filepath.Join(dirs.SnapdStoreSSLCertsDir, "test-server-certs.pem")
err = os.MkdirAll(filepath.Dir(fname), 0755)
c.Assert(err, IsNil)
err = ioutil.WriteFile(fname, extraCerts, 0644)
c.Assert(err, IsNil)

r := devicestatetest.MockGadget(c, st, "gadget", snap.R(2), nil)
defer r()
Expand Down