Skip to content

Commit 15f021a

Browse files
authored
Merge pull request #2716 from dzhwinter/save_state
Pserver Save state
2 parents 8c615e8 + e8296ff commit 15f021a

File tree

6 files changed

+146
-31
lines changed

6 files changed

+146
-31
lines changed

go/cmd/pserver/pserver.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ func main() {
2020
"comma separated endpoint string for pserver to connect to etcd")
2121
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
2222
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
23+
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
24+
checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds")
2325
logLevel := flag.String("log-level", "info",
2426
"log level, possible values: debug, info, warning, error, fatal, panic")
2527
flag.Parse()
@@ -31,18 +33,20 @@ func main() {
3133
log.SetLevel(level)
3234

3335
var idx int
36+
var cp pserver.Checkpoint
37+
var e *pserver.EtcdClient
3438
if *index >= 0 {
3539
idx = *index
3640
} else {
3741
timeout := time.Second * time.Duration((*etcdTimeout))
38-
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
42+
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
3943
idx, err = e.Register()
4044
if err != nil {
4145
panic(err)
4246
}
4347
}
4448

45-
s, err := pserver.NewService(idx)
49+
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
4650
if err != nil {
4751
panic(err)
4852
}

go/pserver/etcd_client.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ const (
1818
PsDesired = "/ps_desired"
1919
// PsAddr is the base dir for pserver to store their addr
2020
PsPath = "/ps/"
21+
// PsCheckpoint is the etcd path for store checkpoints information
22+
PsCheckpoint = "/checkpoints/"
2123
)
2224

2325
// EtcdClient is the etcd client that the pserver uses for fault
@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
186188

187189
return idx, nil
188190
}
191+
192+
// PutKey put into etcd with value by key specified
193+
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error {
194+
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
195+
_, err := e.etcdClient.Put(ctx, key, string(value))
196+
cancel()
197+
if err != nil {
198+
return err
199+
}
200+
return nil
201+
}

go/pserver/optimizer.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,28 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
3535
return (*[1 << 30]byte)(p)[:len:len]
3636
}
3737

38-
func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
38+
func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
3939
o := &optimizer{}
4040
o.elementType = paramWithConfigs.Param.ElementType
4141
p := paramWithConfigs.Param
4242
c := paramWithConfigs.Config
43+
s := State
4344
log.WithFields(log.Fields{
4445
"ElementType": p.ElementType,
4546
"ParamSize": len(p.Content),
4647
"ConfigSize": len(c),
48+
"StateSize": len(s),
4749
}).Info("New Optimizer Created with config:")
4850
var cbuffer unsafe.Pointer
4951
cbuffer = C.malloc(C.size_t(len(p.Content)))
5052
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
53+
var cstate unsafe.Pointer
54+
if len(s) != 0 {
55+
cstate = unsafe.Pointer(&s[0])
56+
}
57+
5158
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
52-
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float),
53-
(*C.char)(nullPtr), 0)
59+
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), (*C.char)(cstate), C.int(len(s)))
5460
return o
5561
}
5662

@@ -60,6 +66,12 @@ func (o *optimizer) GetWeights() []byte {
6066
return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float)
6167
}
6268

69+
func (o *optimizer) GetStates() []byte {
70+
var cbuffer *C.char
71+
cbuffer_len := C.paddle_optimizer_get_state(o.opt, &cbuffer)
72+
return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbuffer_len))
73+
}
74+
6375
func (o *optimizer) UpdateParameter(g Gradient) error {
6476
if o.elementType != g.ElementType {
6577
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)

go/pserver/optimizer_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
1919
Param: p,
2020
Config: config,
2121
}
22-
o := newOptimizer(param)
22+
o := newOptimizer(param, nil)
2323
o.Cleanup()
2424
}

go/pserver/service.go

Lines changed: 99 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
11
package pserver
22

33
import (
4+
"bufio"
5+
"bytes"
6+
"crypto/md5"
7+
"encoding/gob"
8+
"encoding/hex"
9+
"encoding/json"
410
"errors"
511
"fmt"
12+
"os"
13+
"path/filepath"
14+
"strconv"
615
"sync"
16+
"time"
17+
18+
log "github.com/sirupsen/logrus"
719
)
820

921
// ElementType is the type of elements of a Parameter.
@@ -39,26 +51,55 @@ type ParameterWithConfig struct {
3951
Config []byte // parameter configuration in Proto Buffer format
4052
}
4153

54+
// ParameterCheckpoint is Parameter and State checkpoint
55+
type ParameterCheckpoint struct {
56+
ParamConfig ParameterWithConfig
57+
State []byte
58+
}
59+
60+
// checkpoint signature
61+
type checkpointMeta struct {
62+
UUID string `json:"uuid"`
63+
Md5sum string `json:"md5sum"`
64+
Timestamp string `json:"timestamp"`
65+
}
66+
67+
// Checkpoint is the pserver shard persist in file
68+
type Checkpoint []ParameterCheckpoint
69+
4270
// Gradient is the gradient of the parameter.
4371
type Gradient Parameter
4472

4573
// Service is the RPC service for pserver.
4674
type Service struct {
47-
initialized chan struct{}
48-
idx int
49-
50-
mu sync.Mutex
51-
optMap map[string]*optimizer
75+
initialized chan struct{}
76+
idx int
77+
checkpointInterval time.Duration
78+
checkpointPath string
79+
client *EtcdClient
80+
mu sync.Mutex
81+
optMap map[string]*optimizer
5282
}
5383

5484
// NewService creates a new service, will bypass etcd registration if no
5585
// endpoints specified.
56-
func NewService(idx int) (*Service, error) {
86+
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
5787
s := &Service{
58-
idx: idx,
88+
idx: idx,
89+
checkpointInterval: time.Second * time.Duration(seconds),
90+
checkpointPath: path,
91+
client: client,
5992
}
6093
s.optMap = make(map[string]*optimizer)
6194
s.initialized = make(chan struct{})
95+
96+
if cp != nil {
97+
for _, item := range cp {
98+
p := item.ParamConfig
99+
st := item.State
100+
s.optMap[p.Param.Name] = newOptimizer(p, st)
101+
}
102+
}
62103
return s, nil
63104
}
64105

@@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
78119
// TODO(helin): check if paramWithConfigs.Param.Content is
79120
// properly memory aligned, if not, make copy to a memory
80121
// aligned region.
81-
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs)
122+
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
82123
return nil
83124
}
84125

@@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
139180
return nil
140181
}
141182

142-
// Save tells the parameter server to save parameters.
143-
func (s *Service) Save(path string, dummy *int) error {
183+
// pserver save checkpoint
184+
func (s *Service) doCheckpoint() error {
144185
<-s.initialized
186+
s.mu.Lock()
187+
defer s.mu.Unlock()
188+
189+
cp := make([]ParameterCheckpoint, 0, len(s.optMap))
190+
index := 0
191+
for name, opt := range s.optMap {
192+
var pc ParameterCheckpoint
193+
pc.ParamConfig.Param.Name = name
194+
pc.ParamConfig.Param.ElementType = opt.elementType
195+
pc.ParamConfig.Param.Content = opt.GetWeights()
196+
pc.State = opt.GetStates()
197+
cp[index] = pc
198+
index++
199+
}
200+
var buf bytes.Buffer
201+
encoder := gob.NewEncoder(&buf)
202+
err := encoder.Encode(cp)
203+
if err != nil {
204+
return err
205+
}
206+
207+
cpMeta := checkpointMeta{}
208+
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
209+
cpMeta.Timestamp = time.Now().String()
210+
h := md5.New()
211+
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes()))
145212

146-
// TODO
213+
cpMetajson, _ := json.Marshal(cpMeta)
214+
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
215+
if err != nil {
216+
return err
217+
}
218+
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
219+
log.Info("checkpoint does not exists.")
220+
} else {
221+
err = os.Remove(cpMeta.UUID)
222+
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
223+
}
224+
f, err := os.Create(cpMeta.UUID)
225+
defer f.Close()
226+
if err != nil {
227+
return err
228+
}
229+
writer := bufio.NewWriter(f)
230+
_, err = writer.Write(buf.Bytes())
231+
writer.Flush()
232+
if err != nil {
233+
return err
234+
}
147235
return nil
148236
}

go/pserver/service_test.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ const (
1515
)
1616

1717
func TestServiceFull(t *testing.T) {
18-
s, err := pserver.NewService(0)
18+
var cp pserver.Checkpoint
19+
s, err := pserver.NewService(0, 1, "", nil, cp)
1920
if err != nil {
2021
t.Error(err)
2122
}
@@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) {
8687
}
8788

8889
func TestMultipleInit(t *testing.T) {
89-
s, err := pserver.NewService(0)
90+
var cp pserver.Checkpoint
91+
s, err := pserver.NewService(0, 1, "", nil, cp)
9092
if err != nil {
9193
t.Error(err)
9294
}
@@ -102,15 +104,17 @@ func TestMultipleInit(t *testing.T) {
102104
}
103105

104106
func TestUninitialized(t *testing.T) {
105-
s, err := pserver.NewService(0)
107+
var cp pserver.Checkpoint
108+
s, err := pserver.NewService(0, 1, "", nil, cp)
106109
err = s.SendGrad(pserver.Gradient{}, nil)
107110
if err.Error() != pserver.Uninitialized {
108111
t.FailNow()
109112
}
110113
}
111114

112115
func TestBlockUntilInitialized(t *testing.T) {
113-
s, err := pserver.NewService(0)
116+
var cp pserver.Checkpoint
117+
s, err := pserver.NewService(0, 1, "", nil, cp)
114118
if err != nil {
115119
t.Error(err)
116120
}
@@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
128132
ch <- struct{}{}
129133
}()
130134

131-
wg.Add(1)
132-
go func() {
133-
err := s.Save("", nil)
134-
if err != nil {
135-
errCh <- err
136-
}
137-
wg.Done()
138-
ch <- struct{}{}
139-
}()
140-
141135
time.Sleep(50 * time.Millisecond)
142136

143137
select {
@@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) {
170164

171165
wg.Wait()
172166
}
167+
168+
func TestCheckpointSpeed(t *testing.T) {
169+
//TODO(zhihong): test speed
170+
}

0 commit comments

Comments
 (0)