 from torchrl.collectors import SyncDataCollector
 from torchrl.data import NonTensor
 from torchrl.data.replay_buffers.samplers import SliceSamplerWithoutReplacement
-from torchrl.data.tensor_specs import Composite
+from torchrl.data.tensor_specs import Box, Composite, TensorSpec
 from torchrl.envs import ChessEnv
 from torchrl.envs.transforms import (
     ConditionalPolicySwitch,
@@ -75,6 +75,41 @@ def _reset(self, tensordict, tensordict_reset):
         return tensordict_reset
 
 
+class Score(Transform):
+    def __init__(self, input_queue, output_queue):
+        super().__init__()
+        self.input_queue = input_queue
+        self.output_queue = output_queue
+
+    def _step(self, tensordict, next_tensordict):
+        fen = next_tensordict["fen"]
+        self.input_queue.put(fen)
+        _, score = self.output_queue.get()
+        next_tensordict["score"] = torch.tensor(
+            score, device="cuda:7", dtype=torch.bfloat16
+        )
+        return next_tensordict
+
+    def _reset(self, tensordict, tensordict_reset):
+        fen = tensordict_reset["fen"]
+        self.input_queue.put(fen)
+        _, score = self.output_queue.get()
+        tensordict_reset["score"] = torch.tensor(
+            score, device="cuda:7", dtype=torch.bfloat16
+        )
+        return tensordict_reset
+
+    def transform_observation_spec(self, observation_spec: Composite):
+        if not isinstance(observation_spec, Composite):
+            raise ValueError(
+                f"observation_spec was expected to be of type Composite. Got {type(observation_spec)} instead."
+            )
+        observation_spec["observation"] = TensorSpec(
+            (), Box(), dtype=torch.bfloat16, device="cuda:7"
+        )
+        return observation_spec
+
+
 class LLMInputTransform(Transform):
     def __init__(self, san_moves):
         super().__init__()
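
The Score transform above queries the engine for an evaluation of the current FEN at every step and reset, and stores it as a bfloat16 "score" entry. Note that torchrl's TensorSpec is an abstract base class, so a concrete spec is normally used when extending the observation spec; a minimal sketch of that alternative, assuming a torchrl version that exposes Unbounded and that the new entry is keyed "score" to match the values written in _step/_reset:

    # Sketch only, not part of the patch: declare the engine evaluation as an
    # unbounded bfloat16 scalar. Unbounded and the "score" key are assumptions.
    from torchrl.data import Unbounded

    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
        observation_spec["score"] = Unbounded(shape=(), dtype=torch.bfloat16, device="cuda:7")
        return observation_spec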
@@ -159,9 +194,14 @@ def run_player(input_queue, output_queue):
 
             output = process.stdout.readline()
             if output:
-                # print(f"Output: {output.strip()}")
+                print(f"Output: {output.strip()}")
                 move = re.search(r"bestmove (.*)", output.strip()).group(1)
-                output_queue.put(move)
+
+                output = process.stdout.readline()
+                print(f"Output scores: {output.strip()}")
+                score = re.search(r"score (.*)", output.strip()).group(1)
+
+                output_queue.put((move, int(score)))
 
         except queue.Empty:
             continue
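
run_player now reads one extra line of engine output after "bestmove" and forwards a (move, score) pair through the queue. If the engine speaks standard UCI, evaluations arrive on "info" lines such as "info depth 12 score cp 35 ...", so a more targeted pattern than r"score (.*)" could look like the sketch below; the exact output format of the engine wrapper used here is an assumption.

    # Sketch only: pull a centipawn or mate score out of a UCI "info" line.
    import re

    def parse_uci_score(line: str):
        m = re.search(r"score (cp|mate) (-?\d+)", line)
        if m is None:
            return None
        kind, value = m.group(1), int(m.group(2))
        # Map mate announcements to a large sentinel so they dominate cp scores.
        return value if kind == "cp" else (100_000 if value > 0 else -100_000)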
@@ -179,7 +219,7 @@ def run_player(input_queue, output_queue):
 def setup_env(input_queue, output_queue, tokenizer):
     def policy_sunfish(td):
         input_queue.put(td["fen"])
-        move = output_queue.get()
+        move, _ = output_queue.get()
         san = env.board.san(chess.Move.from_uci(move))
         san_idx = env.san_moves.index(san)
         td["action"] = torch.tensor(san_idx)
@@ -205,6 +245,7 @@ def policy_sunfish(td):
             tokenizer=tokenizer,
         )
     )
+    env.append_transform(Score(input_queue, output_queue))
     env.reset()
     return env
 
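Appending Score last means every reset and step of the chess env now carries the engine evaluation alongside the FEN. A rough sketch of the wiring this assumes (the worker model is an assumption; the script may drive run_player from a thread or a separate process, and tokenizer is built elsewhere):

    # Sketch only: one engine worker serves both the sunfish policy and the
    # Score transform through a shared pair of queues.
    import queue
    import threading

    input_queue, output_queue = queue.Queue(), queue.Queue()
    threading.Thread(target=run_player, args=(input_queue, output_queue), daemon=True).start()
    env = setup_env(input_queue, output_queue, tokenizer)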
@@ -405,16 +446,28 @@ def remove_logits(td):
         return_composite=True,
     )
 
-    class CriticHead(torch.nn.Module):
+    # class CriticHead(torch.nn.Module):
+    #     def __init__(self):
+    #         super().__init__()
+    #         self.m = torch.nn.Linear(3584, 1, dtype=torch.bfloat16)
+
+    #     def forward(self, hidden):
+    #         return self.m(hidden).squeeze(-1).sum(-1, keepdim=True)
+
+    # critic_llm_policy = Seq(
+    #     Mod(CriticHead(), in_keys=["hidden"], out_keys=["state_value"]),
+    # )
+
+    class CriticLLMPolicy(torch.nn.Module):
         def __init__(self):
             super().__init__()
-            self.m = torch.nn.Linear(3584, 1, dtype=torch.bfloat16)
 
-        def forward(self, hidden):
-            return self.m(hidden).squeeze(-1).sum(-1, keepdim=True)
+        def forward(self, score):
+            # breakpoint()
+            return score.unsqueeze(-1)
 
     critic_llm_policy = Seq(
-        Mod(CriticHead(), in_keys=["hidden"], out_keys=["state_value"]),
+        Mod(CriticLLMPolicy(), in_keys=["score"], out_keys=["state_value"]),
     )
 
     return actor_llm_policy, data_llm_policy, critic_llm_policy, tokenizer
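
The learned CriticHead over the LLM hidden states is commented out; the critic now passes the engine score straight through as the value estimate, so GAE is computed against the engine evaluation rather than a trained value head. A small usage sketch, assuming Mod/Seq alias tensordict.nn.TensorDictModule/TensorDictSequential and that critic_llm_policy built above is in scope:

    # Sketch only: the critic just reshapes "score" into a (batch, 1) state value.
    import torch
    from tensordict import TensorDict

    td = TensorDict({"score": torch.tensor([35.0, -120.0], dtype=torch.bfloat16)}, [2])
    out = critic_llm_policy(td)
    assert out["state_value"].shape == (2, 1)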
@@ -425,12 +478,21 @@ def play(env, data_llm_policy, actor_llm_policy, tokenizer):
 
     rb = ReplayBuffer(
         storage=LazyStackStorage(100),
-        batch_size=8,
+        batch_size=48,
         sampler=SliceSamplerWithoutReplacement(slice_len=8, end_key=("next", "done")),
     )
 
+    # def breakpointy(td):
+    #     breakpoint()
+    #     return td
+
+    # rb.append_transform(breakpointy)
+
+    # Temporarily patched fbcode/pytorch/tensordict/tensordict/_lazy.py?lines=1502
     rb.append_transform(lambda td: td.densify(layout=torch.jagged))
 
+    # rb.append_transform(breakpointy)
+
     # obs_tokens in layout=torch.jagged errors with Qwen
     #   File "/home/mg1998/.conda/envs/rl/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 859, in forward
     #     cache_position = torch.arange(
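
With batch_size raised to 48 and slice_len=8, every sample drawn from the buffer is made of 6 contiguous 8-step trajectory slices, and the densify(layout=torch.jagged) transform packs the ragged token fields of the lazily stacked tensordicts into nested jagged tensors before they reach the loss. A usage sketch, assuming the buffer configured above has already been filled by the collector:

    # Sketch only: 48 transitions per sample = 6 slices x 8 consecutive steps.
    sample = rb.sample()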
@@ -486,6 +548,7 @@ def obs_token_transform(td):
 
     data = gae(data)
     loss = loss_module(data)
+    breakpoint()
     loss.sum(reduce=True).backward()
     torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 0.5)
     optim.step()