|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | + |
| 4 | +import gym |
| 5 | +import numpy as np |
| 6 | + |
| 7 | + |
class AverageMeter:
    """Track the most recent value and the running weighted mean of a metric.

    Attributes:
        name: Label used when the meter is rendered with ``str()``.
        fmt: Format-spec suffix (for example ``":f"`` or ``":.3f"``) applied
            to both the current value and the average in ``__str__``.
        val: The value passed to the most recent ``update`` call.
        sum: Weighted sum of all values seen since the last ``reset``.
        count: Total weight accumulated since the last ``reset``.
        avg: ``sum / count``, or 0 while ``count`` is still 0.
    """

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest observation, weighted by ``n``.

        Args:
            val: The new observation.
            n: Weight of the observation (for example a batch size).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero total weight (e.g. update(x, n=0) on a fresh meter) would
        # otherwise raise ZeroDivisionError; keep the average at its reset
        # value until some weight has been accumulated.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Return ``"<name> <val> (<avg>)"`` with ``fmt`` applied to both numbers."""
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(name=self.name, val=self.val, avg=self.avg)
| 31 | + |
| 32 | + |
class ProgressMeter:
    """Print a ``[batch/total]`` counter followed by a set of meters.

    Args:
        num_batches: Total number of batches; fixes the counter's width.
        *meters: Objects rendered with ``str()`` after the counter.
        prefix: Text placed directly in front of the counter.
    """

    def __init__(self, num_batches, *meters, prefix=""):
        self.prefix = prefix
        self.meters = meters
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)

    def print(self, batch):
        """Write one tab-separated line: prefix + counter, then each meter."""
        counter = self.prefix + self.batch_fmtstr.format(batch)
        print("\t".join([counter, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Build the ``"[{:Nd}/total]"`` template used for the batch counter."""
        # Pad the running index to as many digits as the total has.
        width = len(str(num_batches // 1))
        slot = "{{:{}d}}".format(width)
        return "[{}/{}]".format(slot, slot.format(num_batches))
| 48 | + |
| 49 | + |
env = gym.make("KellyCoinflip-v0")

# Sweep p over 0..100 and, for each setting, run 1000 episodes in which every
# action is p * env.wealth; then print p alongside the reward meter.
# NOTE(review): presumably p acts as a percentage of current wealth -- confirm
# against the environment's action units.
for p in np.arange(100 + 1):
    rs = AverageMeter("Rewards", ":f")
    for _ in range(1000):
        s = env.reset()

        # done starts out False in the original, so the loop body always runs
        # at least once before the termination flag is checked.
        while True:
            s_, r, done, _ = env.step(p * env.wealth)
            s = s_
            if done:
                break
        # NOTE(review): the source's indentation was lost; this assumes the
        # meter is updated once per episode with the final step's reward.
        rs.update(r)
    print(p, rs)
0 commit comments