From a3d0a9d25b83cfb47a549fc17205790c6c9d1549 Mon Sep 17 00:00:00 2001 From: James Henstridge Date: Wed, 26 Jun 2019 16:26:50 +0800 Subject: daemon: replace shutdownServer with net/http's native shutdown support --- daemon/daemon.go | 124 ++++++++++++++++---------------------------------- daemon/daemon_test.go | 14 +++--- 2 files changed, 45 insertions(+), 93 deletions(-) (limited to 'daemon') diff --git a/daemon/daemon.go b/daemon/daemon.go index 7e6fba320a..820283ce71 100644 --- a/daemon/daemon.go +++ b/daemon/daemon.go @@ -20,6 +20,7 @@ package daemon import ( + "context" "fmt" "net" "net/http" @@ -57,10 +58,11 @@ var systemdSdNotify = systemd.SdNotify type Daemon struct { Version string overlord *overlord.Overlord + connTracker *connTracker snapdListener net.Listener - snapdServe *shutdownServer + snapdServe *http.Server snapListener net.Listener - snapServe *shutdownServer + snapServe *http.Server tomb tomb.Tomb router *mux.Router standbyOpinions *standby.StandbyOpinions @@ -375,97 +377,42 @@ var ( shutdownTimeout = 25 * time.Second ) -// shutdownServer supplements a http.Server with graceful shutdown. -// TODO: with go1.8 http.Server itself grows a graceful Shutdown method -type shutdownServer struct { - l net.Listener - httpSrv *http.Server - - mu sync.Mutex - conns map[net.Conn]http.ConnState - shuttingDown bool -} - -func newShutdownServer(l net.Listener, h http.Handler) *shutdownServer { - srv := &http.Server{ - Handler: h, - } - ssrv := &shutdownServer{ - l: l, - httpSrv: srv, - conns: make(map[net.Conn]http.ConnState), - } - srv.ConnState = ssrv.trackConn - return ssrv +func httpShutdown(server *http.Server) error { + // We're using the background context here because the tomb's + // context will likely already have been cancelled when we are + // called. + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + return server.Shutdown(ctx) } -func (srv *shutdownServer) Serve() error { - return srv.httpSrv.Serve(srv.l) +type connTracker struct { + mu sync.Mutex + conns map[net.Conn]struct{} } -func (srv *shutdownServer) CanStandby() bool { - srv.mu.Lock() - defer srv.mu.Unlock() +func (ct *connTracker) CanStandby() bool { + ct.mu.Lock() + defer ct.mu.Unlock() - for _, state := range srv.conns { - if state != http.StateIdle { - return false - } - } - return true + return len(ct.conns) == 0 } -func (srv *shutdownServer) trackConn(conn net.Conn, state http.ConnState) { - srv.mu.Lock() - defer srv.mu.Unlock() +func (ct *connTracker) trackConn(conn net.Conn, state http.ConnState) { + ct.mu.Lock() + defer ct.mu.Unlock() // we ignore hijacked connections, if we do things with websockets // we'll need custom shutdown handling for them - if state == http.StateClosed || state == http.StateHijacked { - delete(srv.conns, conn) - return - } - if srv.shuttingDown && state == http.StateIdle { - conn.Close() - delete(srv.conns, conn) - return - } - srv.conns[conn] = state -} - -func (srv *shutdownServer) finishShutdown() error { - toutC := time.After(shutdownTimeout) - - srv.mu.Lock() - defer srv.mu.Unlock() - - srv.shuttingDown = true - for conn, state := range srv.conns { - if state == http.StateIdle { - conn.Close() - delete(srv.conns, conn) - } - } - - doWait := true - for doWait { - if len(srv.conns) == 0 { - return nil - } - srv.mu.Unlock() - select { - case <-time.After(200 * time.Millisecond): - case <-toutC: - doWait = false - } - srv.mu.Lock() + if state == http.StateNew || state == http.StateActive { + ct.conns[conn] = struct{}{} + } else { + delete(ct.conns, conn) } - return fmt.Errorf("cannot gracefully finish, still active connections on %v after %v", srv.l.Addr(), shutdownTimeout) } func (d *Daemon) initStandbyHandling() { d.standbyOpinions = standby.New(d.overlord.State()) - d.standbyOpinions.AddOpinion(d.snapdServe) - d.standbyOpinions.AddOpinion(d.snapServe) + d.standbyOpinions.AddOpinion(d.connTracker) d.standbyOpinions.AddOpinion(d.overlord) d.standbyOpinions.AddOpinion(d.overlord.SnapManager()) d.standbyOpinions.AddOpinion(d.overlord.DeviceManager()) @@ -502,10 +449,15 @@ func (d *Daemon) Start() { } }) - if d.snapListener != nil { - d.snapServe = newShutdownServer(d.snapListener, logit(d.router)) + d.connTracker = &connTracker{conns: make(map[net.Conn]struct{})} + d.snapdServe = &http.Server{ + Handler: logit(d.router), + ConnState: d.connTracker.trackConn, + } + d.snapServe = &http.Server{ + Handler: logit(d.router), + ConnState: d.connTracker.trackConn, } - d.snapdServe = newShutdownServer(d.snapdListener, logit(d.router)) // enable standby handling d.initStandbyHandling() @@ -516,7 +468,7 @@ func (d *Daemon) Start() { d.tomb.Go(func() error { if d.snapListener != nil { d.tomb.Go(func() error { - if err := d.snapServe.Serve(); err != nil && d.tomb.Err() == tomb.ErrStillAlive { + if err := d.snapServe.Serve(d.snapListener); err != http.ErrServerClosed && d.tomb.Err() == tomb.ErrStillAlive { return err } @@ -524,7 +476,7 @@ func (d *Daemon) Start() { }) } - if err := d.snapdServe.Serve(); err != nil && d.tomb.Err() == tomb.ErrStillAlive { + if err := d.snapdServe.Serve(d.snapdListener); err != http.ErrServerClosed && d.tomb.Err() == tomb.ErrStillAlive { return err } @@ -586,9 +538,9 @@ func (d *Daemon) Stop(sigCh chan<- os.Signal) error { time.Sleep(rebootNoticeWait) } - d.tomb.Kill(d.snapdServe.finishShutdown()) + d.tomb.Kill(httpShutdown(d.snapdServe)) if d.snapListener != nil { - d.tomb.Kill(d.snapServe.finishShutdown()) + d.tomb.Kill(httpShutdown(d.snapServe)) } if !restartSystem { diff --git a/daemon/daemon_test.go b/daemon/daemon_test.go index f48205a629..514796347e 100644 --- a/daemon/daemon_test.go +++ b/daemon/daemon_test.go @@ -980,16 +980,16 @@ func (s *daemonSuite) TestRestartIntoSocketModePendingChanges(c *check.C) { c.Check(d.restartSocket, check.Equals, false) } -func (s *daemonSuite) TestShutdownServerCanShutdown(c *check.C) { - shush := newShutdownServer(nil, nil) - c.Check(shush.CanStandby(), check.Equals, true) +func (s *daemonSuite) TestConnTrackerCanShutdown(c *check.C) { + ct := &connTracker{conns: make(map[net.Conn]struct{})} + c.Check(ct.CanStandby(), check.Equals, true) con := &net.IPConn{} - shush.conns[con] = http.StateActive - c.Check(shush.CanStandby(), check.Equals, false) + ct.trackConn(con, http.StateActive) + c.Check(ct.CanStandby(), check.Equals, false) - shush.conns[con] = http.StateIdle - c.Check(shush.CanStandby(), check.Equals, true) + ct.trackConn(con, http.StateIdle) + c.Check(ct.CanStandby(), check.Equals, true) } func doTestReq(c *check.C, cmd *Command, mth string) *httptest.ResponseRecorder { -- cgit v1.2.3