Hello. Here is a custom Env. I would like to discuss my mistakes, because it passes “check_env_specs” (see Tests at the bottom of the code) but no longer does after transformations are applied…
CosTrader.py
import torch
from tensordict.tensordict import TensorDict, TensorDictBase
from torchrl.data import Bounded, Composite, Unbounded, Categorical, Binary
from torchrl.envs import EnvBase
from torchrl.envs.utils import check_env_specs
from typing import Optional

#############
# Module-level switch: when True, pygame is imported and every reset/step
# draws the current state in a window.
rendering = True
#############
if rendering:
    import pygame
    from pygame import gfxdraw


def gen_params(batch_size=None) -> TensorDictBase:
    """Build the default parameters of the price curve f(x) = a*cos(b*(x-h)) + k.

    Args:
        batch_size: optional batch shape; when non-empty the parameter
            tensordict is expanded to that shape and made contiguous.

    Returns:
        A tensordict with a single nested entry ``"params"`` holding
        ``a``, ``b``, ``h``, ``k`` and ``dt``.
    """
    if batch_size is None:
        batch_size = []
    # f(x) = a.cos(b(x-h))+k
    td = TensorDict(
        {
            "params": TensorDict(
                {
                    "a": 0.2,  # amplitude
                    "b": 1,  # frequency
                    "h": 0.5,  # phase
                    "k": 1.25,  # level (vertical offset)
                    "dt": 0.05,  # time increment added to t at every step
                },
                [],
            )
        },
        [],
    )
    if batch_size:
        td = td.expand(batch_size).contiguous()
    return td


def make_composite_from_td(td):
    """Recursively build a ``Composite`` spec mirroring the structure of ``td``.

    Nested tensordicts become nested ``Composite`` specs; every tensor leaf
    becomes an ``Unbounded`` spec with the leaf's dtype, device and shape.
    """
    composite = Composite(
        {
            key: make_composite_from_td(tensor)
            if isinstance(tensor, TensorDictBase)
            else Unbounded(
                dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
            )
            for key, tensor in td.items()
        },
        shape=td.shape,
    )
    return composite


class CosTraderEnv(EnvBase):
    """Toy trading environment whose price ("bid") follows a cosine curve.

    Actions (``Categorical(n=4)``): 0 = close, 1 = do nothing, 2 = buy,
    3 = sell. The state carries the price, its slope ("angle"), the open
    buy/sell positions ("posA"/"posV"), the time "t", a balance ("solde"),
    an "action_mask" and the curve parameters under "params".
    """

    batch_locked = False
    AUTO_UNWRAP_TRANSFORMED_ENV = False

    def __init__(self, td_params=None, seed=None, device="cpu", batch_size=None):
        """Create the environment.

        Args:
            td_params: tensordict with a ``"params"`` entry; generated with
                ``gen_params(batch_size)`` when ``None``.
            seed: integer seed; a random one is drawn when ``None``.
            device: device handed to ``EnvBase``.
            batch_size: batch shape of the environment (defaults to ``[]``).
        """
        # ``None`` instead of a mutable ``[]`` default, which would be shared
        # between all calls.
        if batch_size is None:
            batch_size = []
        super().__init__(device=device, batch_size=batch_size)
        if td_params is None:
            td_params = self.gen_params(batch_size)
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)
        # pygame handles, created lazily by ``render``.
        self.screen = None
        self.clock = None

    def _step(self, tensordict):
        """Advance the simulation by ``dt`` and compute the reward.

        Reads the state entries, ``"params"`` and ``"action"`` from
        ``tensordict`` and returns a new tensordict with the next state,
        ``"reward"``, ``"done"`` (always False) and the next ``"action_mask"``.
        """
        a = tensordict["params", "a"]
        b = tensordict["params", "b"]
        h = tensordict["params", "h"]
        k = tensordict["params", "k"]
        dt = tensordict["params", "dt"]

        bid = tensordict["bid"]  # price
        angle = tensordict["angle"]  # slope of the curve (only used for rendering here)
        posA = tensordict["posA"]  # buy position (entry price, 0 when none)
        posV = tensordict["posV"]  # sell position (entry price, 0 when none)
        t = tensordict["t"]  # time
        # action: 0 = close, 1 = nothing, 2 = buy, 3 = sell
        u = tensordict["action"].squeeze(-1)

        # One indicator (0/1) per action.
        C = torch.where(u == 0, 1, 0)
        R = torch.where(u == 1, 1, 0)
        A = torch.where(u == 2, 1, 0)
        V = torch.where(u == 3, 1, 0)

        # Position indicators (0/1).
        no_pos = torch.where((posA == 0) & (posV == 0), 1, 0)
        in_posA = torch.where(posA > 0, 1, 0)
        in_posV = torch.where(posV > 0, 1, 0)

        # A position is only opened (at the current price) when none is open.
        addA = torch.where(u == 2, bid * no_pos, 0)
        addV = torch.where(u == 3, bid * no_pos, 0)

        # add new pos to old pos
        new_posA = posA + addA
        new_posV = posV + addV

        # remove closed pos
        new_posA = torch.where(u == 0, 0, new_posA)
        new_posV = torch.where(u == 0, 0, new_posV)

        new_t = t + dt
        # f(x) = a.cos(b(x-h))+k, evaluated at the *new* time. Using the old
        # ``t`` made the first step after a reset return the same price as
        # the reset and kept (bid, t) one ``dt`` out of phase afterwards.
        new_bid = a * (b * (new_t - h)).cos() + k
        new_angle = (new_bid - bid) / dt
        new_solde = new_bid - posA - posV

        ##################### REWARD ########################
        # If posA > 0 then posV == 0 and vice versa.
        # NOTE(review): only ``in_posA`` gates the close reward, so closing a
        # sell position yields no reward here - confirm this is intended.
        reward_C = C * in_posA * (bid - (posA + posV))
        reward_R = (R * (in_posA + in_posV) - R * no_pos) * (new_bid - bid)
        reward_A = A * no_pos * (new_bid - bid)
        reward_V = V * no_pos * (new_bid - bid)
        reward = reward_C * 10 + reward_R + reward_A + reward_V

        #### mask for the next step, selected by the action just taken
        lut = torch.tensor(
            [
                [False, True, True, True],  # action 0 close
                [False, True, True, True],  # action 1 nothing
                [True, True, False, False],  # action 2 buy
                [True, True, False, False],  # action 3 sell
            ],
            dtype=torch.bool,
            device=u.device,  # keep the lookup on the same device as its index
        )
        new_mask = lut[u]

        # done = C.bool() ?
        done = torch.zeros_like(reward, dtype=torch.bool)

        nextTD = TensorDict(
            {
                "params": tensordict["params"],
                "angle": new_angle,
                "bid": new_bid,
                "posA": new_posA,
                "posV": new_posV,
                "t": new_t,
                "solde": new_solde,
                "reward": reward,
                "done": done,
                "action_mask": new_mask,
            },
            tensordict.shape,
        )

        if rendering:
            self.state = (
                angle.tolist(),
                bid.tolist(),
                new_posA.tolist(),
                new_posV.tolist(),
            )
            self.last_u = 0
            self.render()
        return nextTD

    def _reset(self, tensordict):
        """Return the initial state, with a freshly sampled phase ``h``.

        Args:
            tensordict: optional tensordict carrying ``"params"``. When it is
                ``None``, empty, or has no ``"params"`` entry, default
                parameters are generated for ``self.batch_size``.
        """
        # A wrapper may call reset with a non-empty tensordict that does not
        # carry "params"; indexing it would raise KeyError.
        if (
            tensordict is None
            or tensordict.is_empty()
            or "params" not in tensordict.keys()
        ):
            tensordict = self.gen_params(batch_size=self.batch_size)

        shape = tensordict.shape
        t = torch.zeros(shape, device=self.device)
        posA = torch.zeros(shape, device=self.device)
        posV = torch.zeros(shape, device=self.device)
        solde = torch.zeros(shape, device=self.device)
        angle = torch.zeros(shape, device=self.device)  # ...

        # Work on a copy so the caller's tensordict is not modified, and keep
        # the parameters on the same device as the state tensors.
        params = tensordict["params"].clone().to(self.device)
        # Random phase, uniform in [-2, 2).
        h = torch.rand(shape, device=self.device) * 4 - 2
        params["h"] = h

        a = params["a"]
        b = params["b"]
        k = params["k"]
        bid = a * (b * (t - h)).cos() + k  # a.cos(b(x-h))+k

        new_action_mask = self._make_action_mask()

        out = TensorDict(
            {
                "params": params,
                "angle": angle,
                "bid": bid,
                "posA": posA,
                "posV": posV,
                "t": t,
                "solde": solde,
                "action_mask": new_action_mask,
            },
            batch_size=shape,
        )

        if rendering:
            self.last_u = None
            self.state = (angle.tolist(), bid.tolist(), posA.tolist(), posV.tolist())
            self.render()
        return out

    def _make_spec(self, td_params):
        """Define the action, observation, state and reward specs.

        Args:
            td_params: tensordict whose ``"params"`` entry is mirrored into
                the observation spec via ``make_composite_from_td``.
        """
        self.batch_size = getattr(self, "batch_size", td_params.shape)
        if not isinstance(self.batch_size, torch.Size):
            self.batch_size = torch.Size(self.batch_size)

        # Action spec
        self.action_spec = Categorical(n=4, shape=(*self.batch_size, 1))

        # Observation spec
        self.observation_spec = Composite(
            angle=Bounded(
                low=-torch.pi / 2,
                high=torch.pi / 2,
                shape=self.batch_size,
                dtype=torch.float32,
            ),
            bid=Unbounded(shape=self.batch_size, dtype=torch.float32),
            posA=Unbounded(shape=self.batch_size, dtype=torch.float32),
            posV=Unbounded(shape=self.batch_size, dtype=torch.float32),
            t=Bounded(
                low=0,
                high=1000,
                shape=self.batch_size,
                dtype=torch.float32,
            ),
            solde=Unbounded(shape=self.batch_size, dtype=torch.float32),
            action_mask=Binary(
                n=4, dtype=torch.bool, shape=(*self.batch_size, 4)
            ),
            params=make_composite_from_td(td_params["params"]),
            shape=self.batch_size,
        )

        self.state_spec = self.observation_spec.clone()

        self.reward_spec = Unbounded(shape=(*self.batch_size, 1))

    gen_params = staticmethod(gen_params)

    def _make_action_mask(self):
        """Return the initial action mask: "close" disabled, the rest enabled."""
        mask = torch.tensor([False, True, True, True], device=self.device)  # n=4
        mask = mask.expand(*self.batch_size, 4)
        return mask

    def _set_seed(self, seed: Optional[int]):
        """Seed torch's global generator and keep a handle to it in ``self.rng``."""
        rng = torch.manual_seed(seed)
        self.rng = rng

    def get_obskeys(self):
        """Return the list of state keys exposed by this helper."""
        return ["angle", "bid", "posA", "t", "solde"]

    @staticmethod
    def _first_scalar(value):
        """Return ``value`` itself, or its first innermost element if it is a list.

        ``self.state`` is built with ``Tensor.tolist()``, which yields a plain
        number for an unbatched env and a (possibly nested) list for a batched
        one; only the first environment of a batch is drawn.
        """
        while isinstance(value, list):
            value = value[0]
        return value

    def render(self):
        """Draw the current ``self.state`` (slope, price, positions) with pygame."""
        if not rendering:
            # pygame was not imported; nothing can be drawn.
            return

        if self.screen is None:
            self.screen_dim = 400
            pygame.init()
            pygame.display.init()
            self.screen = pygame.display.set_mode((self.screen_dim, self.screen_dim))
        if self.clock is None:
            self.clock = pygame.time.Clock()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                raise SystemExit

        self.surf = pygame.Surface((self.screen_dim, self.screen_dim))
        self.surf.fill((255, 255, 255))

        bound = 2.2
        scale = self.screen_dim / (bound * 2)
        offset = self.screen_dim // 2

        # First environment only when the env is batched.
        angle = self._first_scalar(self.state[0])
        bid = self._first_scalar(self.state[1])
        pos_buy = self._first_scalar(self.state[2])
        pos_sell = self._first_scalar(self.state[3])

        # Rod rotated by the slope of the curve, anchored at the window centre.
        rod_length = 1 * scale
        rod_width = 0.02 * scale
        left, right, top, bottom = 0, rod_length, rod_width / 2, -rod_width / 2
        coords = [(left, bottom), (left, top), (right, top), (right, bottom)]
        transformed_coords = []
        for c in coords:
            c = pygame.math.Vector2(c).rotate_rad(angle)  # + np.pi / 2)
            c = (c[0] + offset, c[1] + offset)
            transformed_coords.append(c)
        gfxdraw.aapolygon(self.surf, transformed_coords, (204, 77, 77))
        gfxdraw.filled_polygon(self.surf, transformed_coords, (204, 77, 77))

        gfxdraw.aacircle(self.surf, offset, offset, int(rod_width / 2), (204, 77, 77))
        gfxdraw.filled_circle(
            self.surf, offset, offset, int(rod_width / 2), (204, 77, 77)
        )

        radius = int(0.05 * scale)

        # drawing bid: fixed blue marker at 1.4, black dot at the current price
        gfxdraw.filled_circle(
            self.surf, 10, int(9 * (1.4 - 1) * scale), radius, (0, 0, 255)
        )
        gfxdraw.filled_circle(
            self.surf, 5, int(9 * (bid - 1) * scale), radius, (0, 0, 0)
        )

        # drawing posA (green)
        gfxdraw.filled_circle(
            self.surf,
            self.screen_dim - 5,
            int(9 * (pos_buy - 1) * scale),
            radius,
            (0, 255, 0),
        )

        # drawing posV (red)
        gfxdraw.filled_circle(
            self.surf,
            self.screen_dim - 5,
            int(9 * (pos_sell - 1) * scale),
            radius,
            (255, 0, 0),
        )

        # Flip vertically before blitting so larger values are drawn higher.
        self.surf = pygame.transform.flip(self.surf, False, True)
        self.screen.blit(self.surf, (0, 0))

        self.clock.tick(30)
        pygame.display.flip()

    def close(self):
        """Shut down the pygame window if one was opened."""
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()
            # Drop the stale handles so a later render() re-creates them.
            self.screen = None
            self.clock = None


if __name__ == "__main__":
    ### Tests ###
    env = CosTraderEnv(batch_size=torch.Size([]))
    check_env_specs(env)

    env = CosTraderEnv(batch_size=torch.Size([10]))
    check_env_specs(env)

    print("\n* observation_spec:", env.observation_spec)
    print("\n* action_spec:", env.action_spec)
    print("\n* reward_spec:", env.reward_spec)
    print("\n* random action 5: \n", env.action_spec.rand(torch.Size([5])))
    print("\n* random obs 5: \n", env.observation_spec.rand(torch.Size([5])))

    td = env.reset()
    print("\n* reset tensordict", td)
    td = env.rand_step(td)
    print("\n* random step tensordict", td)
    env.close()