Skip to content
8 changes: 6 additions & 2 deletions go/cmd/pserver/pserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse()
Expand All @@ -31,18 +33,20 @@ func main() {
log.SetLevel(level)

var idx int
var cp pserver.Checkpoint
var e *pserver.EtcdClient
if *index >= 0 {
idx = *index
} else {
timeout := time.Second * time.Duration((*etcdTimeout))
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err = e.Register()
if err != nil {
panic(err)
}
}

s, err := pserver.NewService(idx)
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
if err != nil {
panic(err)
}
Expand Down
13 changes: 13 additions & 0 deletions go/pserver/etcd_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ const (
PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr
PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/"
)

// EtcdClient is the etcd client that the pserver uses for fault
Expand Down Expand Up @@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {

return idx, nil
}

// PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
_, err := e.etcdClient.Put(ctx, key, string(value))
cancel()
if err != nil {
return err
}
return nil
}
18 changes: 15 additions & 3 deletions go/pserver/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,28 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return (*[1 << 30]byte)(p)[:len:len]
}

func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType
p := paramWithConfigs.Param
c := paramWithConfigs.Config
s := State
log.WithFields(log.Fields{
"ElementType": p.ElementType,
"ParamSize": len(p.Content),
"ConfigSize": len(c),
"StateSize": len(s),
}).Info("New Optimizer Created with config:")
var cbuffer unsafe.Pointer
cbuffer = C.malloc(C.size_t(len(p.Content)))
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
var cstate unsafe.Pointer
if len(s) != 0 {
cstate = unsafe.Pointer(&s[0])
}

o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float),
(*C.char)(nullPtr), 0)
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), (*C.char)(cstate), C.int(len(s)))
return o
}

Expand All @@ -60,6 +66,12 @@ func (o *optimizer) GetWeights() []byte {
return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float)
}

func (o *optimizer) GetStates() []byte {
var cbuffer *C.char
cbuffer_len := C.paddle_optimizer_get_state(o.opt, &cbuffer)
return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbuffer_len))
}

func (o *optimizer) UpdateParameter(g Gradient) error {
if o.elementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)
Expand Down
2 changes: 1 addition & 1 deletion go/pserver/optimizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
Param: p,
Config: config,
}
o := newOptimizer(param)
o := newOptimizer(param, nil)
o.Cleanup()
}
110 changes: 99 additions & 11 deletions go/pserver/service.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
package pserver

import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"sync"
"time"

log "github.com/sirupsen/logrus"
)

// ElementType is the type of elements of a Parameter.
Expand Down Expand Up @@ -39,26 +51,55 @@ type ParameterWithConfig struct {
Config []byte // parameter configuration in Proto Buffer format
}

// ParameterCheckpoint is Parameter and State checkpoint
type ParameterCheckpoint struct {
ParamConfig ParameterWithConfig
State []byte
}

// checkpoint signature
type checkpointMeta struct {
UUID string `json:"uuid"`
Md5sum string `json:"md5sum"`
Timestamp string `json:"timestamp"`
}

// Checkpoint is the pserver shard persist in file
type Checkpoint []ParameterCheckpoint

// Gradient is the gradient of the parameter.
type Gradient Parameter

// Service is the RPC service for pserver.
type Service struct {
initialized chan struct{}
idx int

mu sync.Mutex
optMap map[string]*optimizer
initialized chan struct{}
idx int
checkpointInterval time.Duration
checkpointPath string
client *EtcdClient
mu sync.Mutex
optMap map[string]*optimizer
}

// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func NewService(idx int) (*Service, error) {
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
idx: idx,
checkpointInterval: time.Second * time.Duration(seconds),
checkpointPath: path,
client: client,
}
s.optMap = make(map[string]*optimizer)
s.initialized = make(chan struct{})

if cp != nil {
for _, item := range cp {
p := item.ParamConfig
st := item.State
s.optMap[p.Param.Name] = newOptimizer(p, st)
}
}
return s, nil
}

Expand All @@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory
// aligned region.
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs)
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
return nil
}

Expand Down Expand Up @@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return nil
}

// Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error {
// pserver save checkpoint
func (s *Service) doCheckpoint() error {
<-s.initialized
s.mu.Lock()
defer s.mu.Unlock()

cp := make([]ParameterCheckpoint, 0, len(s.optMap))
index := 0
for name, opt := range s.optMap {
var pc ParameterCheckpoint
pc.ParamConfig.Param.Name = name
pc.ParamConfig.Param.ElementType = opt.elementType
pc.ParamConfig.Param.Content = opt.GetWeights()
pc.State = opt.GetStates()
cp[index] = pc
index++
}
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
err := encoder.Encode(cp)
if err != nil {
return err
}

cpMeta := checkpointMeta{}
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
cpMeta.Timestamp = time.Now().String()
h := md5.New()
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes()))

// TODO
cpMetajson, _ := json.Marshal(cpMeta)
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
if err != nil {
return err
}
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
log.Info("checkpoint does not exists.")
} else {
err = os.Remove(cpMeta.UUID)
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
}
f, err := os.Create(cpMeta.UUID)
defer f.Close()
if err != nil {
return err
}
writer := bufio.NewWriter(f)
_, err = writer.Write(buf.Bytes())
writer.Flush()
if err != nil {
return err
}
return nil
}
26 changes: 12 additions & 14 deletions go/pserver/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ const (
)

func TestServiceFull(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) {
}

func TestMultipleInit(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
Expand All @@ -102,15 +104,17 @@ func TestMultipleInit(t *testing.T) {
}

func TestUninitialized(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized {
t.FailNow()
}
}

func TestBlockUntilInitialized(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
Expand All @@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
ch <- struct{}{}
}()

wg.Add(1)
go func() {
err := s.Save("", nil)
if err != nil {
errCh <- err
}
wg.Done()
ch <- struct{}{}
}()

time.Sleep(50 * time.Millisecond)

select {
Expand Down Expand Up @@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) {

wg.Wait()
}

func TestCheckpointSpeed(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Speed can be tested with benchmark. Here is an example: https://dave.cheney.net/2013/06/30/how-to-write-benchmarks-in-go

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leave a TODO here, will be tested after reaching an agreement with @Yancey1989 's recover logic.

//TODO(zhihong): test speed
}