|  | 
|  | 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | 
|  | 2 | +
 | 
|  | 3 | +Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +you may not use this file except in compliance with the License. | 
|  | 5 | +You may obtain a copy of the License at | 
|  | 6 | +
 | 
|  | 7 | + http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +
 | 
|  | 9 | +Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +See the License for the specific language governing permissions and | 
|  | 13 | +limitations under the License. */ | 
|  | 14 | + | 
|  | 15 | +#pragma once | 
|  | 16 | +#include <memory> | 
|  | 17 | +#include <random> | 
|  | 18 | +typedef long int64; | 
|  | 19 | +namespace paddle { | 
|  | 20 | +namespace operators { | 
|  | 21 | +namespace math { | 
|  | 22 | + | 
|  | 23 | +// TODO(wanghaoshuang): Support for GPU | 
|  | 24 | + | 
|  | 25 | +/** | 
|  | 26 | +* Sample integers from [0, range). | 
|  | 27 | +*/ | 
|  | 28 | +class Sampler { | 
|  | 29 | + public: | 
|  | 30 | + explicit Sampler(int64 range) : range_(range) { | 
|  | 31 | + PADDLE_ENFORCE_GT(range, 0); | 
|  | 32 | + std::random_device r; | 
|  | 33 | + seed_ = r(); | 
|  | 34 | + } | 
|  | 35 | + explicit Sampler(int64 range, unsigned int seed) | 
|  | 36 | + : range_(range), seed_(seed) { | 
|  | 37 | + PADDLE_ENFORCE_GT(range, 0); | 
|  | 38 | + } | 
|  | 39 | + virtual ~Sampler(); | 
|  | 40 | + // Sample a single value | 
|  | 41 | + virtual int64 Sample() const = 0; | 
|  | 42 | + // The probability that a single call to Sample() returns the given value. | 
|  | 43 | + virtual float Probability(int64 value) const = 0; | 
|  | 44 | + | 
|  | 45 | + int64 range() { return range_; }; | 
|  | 46 | + | 
|  | 47 | + protected: | 
|  | 48 | + const int64 range_; | 
|  | 49 | + unsigned int seed_; | 
|  | 50 | +}; | 
|  | 51 | + | 
|  | 52 | +/** | 
|  | 53 | + * Sample integers from [0, range). | 
|  | 54 | + * And the distribution function is: | 
|  | 55 | + * P(x) = 1 / range | 
|  | 56 | + */ | 
|  | 57 | +class UniformSampler : public Sampler { | 
|  | 58 | + public: | 
|  | 59 | + explicit UniformSampler(int64 range); | 
|  | 60 | + | 
|  | 61 | + explicit UniformSampler(int64 range, unsigned int seed); | 
|  | 62 | + | 
|  | 63 | + ~UniformSampler() override {} | 
|  | 64 | + | 
|  | 65 | + int64 Sample() const override; | 
|  | 66 | + | 
|  | 67 | + float Probability(int64 value) const override; | 
|  | 68 | + | 
|  | 69 | + private: | 
|  | 70 | + const float inv_range_; | 
|  | 71 | + std::shared_ptr<std::mt19937_64> random_engine_; | 
|  | 72 | + std::shared_ptr<std::uniform_int_distribution<>> dist_; | 
|  | 73 | +}; | 
|  | 74 | + | 
|  | 75 | +/** | 
|  | 76 | + * Sample integers from [0, range). | 
|  | 77 | + * And the distribution function is: | 
|  | 78 | + * P(x) = (1/ln(range+1)) * ln(1 + 1/(x + 1)) | 
|  | 79 | + */ | 
|  | 80 | +class LogUniformSampler : public Sampler { | 
|  | 81 | + public: | 
|  | 82 | + explicit LogUniformSampler(int64 range); | 
|  | 83 | + | 
|  | 84 | + explicit LogUniformSampler(int64 range, unsigned int seed); | 
|  | 85 | + | 
|  | 86 | + ~LogUniformSampler() override {} | 
|  | 87 | + | 
|  | 88 | + int64 Sample() const override; | 
|  | 89 | + | 
|  | 90 | + float Probability(int64 value) const override; | 
|  | 91 | + | 
|  | 92 | + private: | 
|  | 93 | + const float log_range_; | 
|  | 94 | + std::shared_ptr<std::mt19937_64> random_engine_; | 
|  | 95 | + std::shared_ptr<std::uniform_real_distribution<>> dist_; | 
|  | 96 | +}; | 
|  | 97 | + | 
|  | 98 | +} // math | 
|  | 99 | +} // namespace operators | 
|  | 100 | +} // namespace paddle | 
0 commit comments