11package  pserver
22
33import  (
4+ "bufio" 
5+ "bytes" 
6+ "crypto/md5" 
7+ "encoding/gob" 
8+ "encoding/hex" 
9+ "encoding/json" 
410"errors" 
511"fmt" 
12+ "os" 
13+ "path/filepath" 
14+ "strconv" 
615"sync" 
16+ "time" 
17+ 
18+ log "github.com/sirupsen/logrus" 
719)
820
921// ElementType is the type of elements of a Parameter. 
@@ -39,26 +51,55 @@ type ParameterWithConfig struct {
3951Config  []byte  // parameter configuration in Proto Buffer format 
4052}
4153
54+ // ParameterCheckpoint is Parameter and State checkpoint 
55+ type  ParameterCheckpoint  struct  {
56+ ParamConfig  ParameterWithConfig 
57+ State  []byte 
58+ }
59+ 
// checkpointMeta is the checkpoint signature that is JSON-encoded and
// published alongside a checkpoint so the checkpoint can be located and
// validated later.
type checkpointMeta struct {
	// UUID identifies the checkpoint; doCheckpoint sets it to the path of
	// the checkpoint file.
	UUID string `json:"uuid"`
	// Md5sum is intended to be the hex-encoded MD5 digest of the
	// serialized checkpoint bytes.
	Md5sum string `json:"md5sum"`
	// Timestamp is the textual form of the time the checkpoint was taken.
	Timestamp string `json:"timestamp"`
}
66+ 
67+ // Checkpoint is the pserver shard persist in file 
68+ type  Checkpoint  []ParameterCheckpoint 
69+ 
4270// Gradient is the gradient of the parameter. 
4371type  Gradient  Parameter 
4472
4573// Service is the RPC service for pserver. 
4674type  Service  struct  {
47- initialized  chan  struct {}
48- idx  int 
49- 
50- mu  sync.Mutex 
51- optMap  map [string ]* optimizer 
75+ initialized  chan  struct {}
76+ idx  int 
77+ checkpointInterval  time.Duration 
78+ checkpointPath  string 
79+ client  * EtcdClient 
80+ mu  sync.Mutex 
81+ optMap  map [string ]* optimizer 
5282}
5383
5484// NewService creates a new service, will bypass etcd registration if no 
5585// endpoints specified. 
56- func  NewService (idx  int ) (* Service , error ) {
86+ func  NewService (idx  int ,  seconds   int ,  path   string ,  client   * EtcdClient ,  cp   Checkpoint ) (* Service , error ) {
5787s  :=  & Service {
58- idx : idx ,
88+ idx : idx ,
89+ checkpointInterval : time .Second  *  time .Duration (seconds ),
90+ checkpointPath : path ,
91+ client : client ,
5992}
6093s .optMap  =  make (map [string ]* optimizer )
6194s .initialized  =  make (chan  struct {})
95+ 
96+ if  cp  !=  nil  {
97+ for  _ , item  :=  range  cp  {
98+ p  :=  item .ParamConfig 
99+ st  :=  item .State 
100+ s .optMap [p .Param .Name ] =  newOptimizer (p , st )
101+ }
102+ }
62103return  s , nil 
63104}
64105
@@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
78119// TODO(helin): check if paramWithConfigs.Param.Content is 
79120// properly memory aligned, if not, make copy to a memory 
80121// aligned region. 
81- s .optMap [paramWithConfigs .Param .Name ] =  newOptimizer (paramWithConfigs )
122+ s .optMap [paramWithConfigs .Param .Name ] =  newOptimizer (paramWithConfigs ,  nil )
82123return  nil 
83124}
84125
@@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
139180return  nil 
140181}
141182
142- // Save tells the parameter server to  save parameters.  
143- func  (s  * Service ) Save ( path   string ,  dummy   * int ) error  {
183+ // pserver  save checkpoint  
184+ func  (s  * Service ) doCheckpoint ( ) error  {
144185<- s .initialized 
186+ s .mu .Lock ()
187+ defer  s .mu .Unlock ()
188+ 
189+ cp  :=  make ([]ParameterCheckpoint , 0 , len (s .optMap ))
190+ index  :=  0 
191+ for  name , opt  :=  range  s .optMap  {
192+ var  pc  ParameterCheckpoint 
193+ pc .ParamConfig .Param .Name  =  name 
194+ pc .ParamConfig .Param .ElementType  =  opt .elementType 
195+ pc .ParamConfig .Param .Content  =  opt .GetWeights ()
196+ pc .State  =  opt .GetStates ()
197+ cp [index ] =  pc 
198+ index ++ 
199+ }
200+ var  buf  bytes.Buffer 
201+ encoder  :=  gob .NewEncoder (& buf )
202+ err  :=  encoder .Encode (cp )
203+ if  err  !=  nil  {
204+ return  err 
205+ }
206+ 
207+ cpMeta  :=  checkpointMeta {}
208+ cpMeta .UUID  =  s .checkpointPath  +  strconv .Itoa (s .idx )
209+ cpMeta .Timestamp  =  time .Now ().String ()
210+ h  :=  md5 .New ()
211+ cpMeta .Md5sum  =  hex .EncodeToString (h .Sum (buf .Bytes ()))
145212
146- // TODO 
213+ cpMetajson , _  :=  json .Marshal (cpMeta )
214+ err  =  s .client .PutKey (filepath .Join (PsCheckpoint , strconv .Itoa (s .idx )), cpMetajson , 3 )
215+ if  err  !=  nil  {
216+ return  err 
217+ }
218+ if  _ , err  =  os .Stat (cpMeta .UUID ); os .IsNotExist (err ) {
219+ log .Info ("checkpoint does not exists." )
220+ } else  {
221+ err  =  os .Remove (cpMeta .UUID )
222+ log .Infof ("checkpoint %s already exsits, removing " , cpMeta .UUID )
223+ }
224+ f , err  :=  os .Create (cpMeta .UUID )
225+ defer  f .Close ()
226+ if  err  !=  nil  {
227+ return  err 
228+ }
229+ writer  :=  bufio .NewWriter (f )
230+ _ , err  =  writer .Write (buf .Bytes ())
231+ writer .Flush ()
232+ if  err  !=  nil  {
233+ return  err 
234+ }
147235return  nil 
148236}
0 commit comments