Skip to content

Commit 64c8e92

Browse files
committed
Fix
1 parent 9c1900c commit 64c8e92

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed

+996
-934
lines changed

paddle/phi/kernels/funcs/binomial_functor.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ inline T stirling_approx_tail(int64_t k) {
4040
}
4141

4242
template <typename T, typename Context>
43-
inline int64_t btrs(const Context& ctx, const T n, const T p) {
43+
inline int64_t btrs(const Context& dev_ctx, const T n, const T p) {
4444
int64_t k;
4545
T U, V, us;
4646
std::uniform_real_distribution<T> dist(0.0, 1.0);
47-
auto gen_ptr = ctx.GetGenerator();
47+
auto gen_ptr = dev_ctx.GetGenerator();
4848
auto engine = gen_ptr->GetCPUEngine();
4949

5050
const T stddev = std::sqrt(n * p * (1 - p));
@@ -87,13 +87,15 @@ inline int64_t btrs(const Context& ctx, const T n, const T p) {
8787
}
8888

8989
template <typename T, typename Context>
90-
inline int64_t binomial_inversion(const Context& ctx, const T n, const T p) {
90+
inline int64_t binomial_inversion(const Context& dev_ctx,
91+
const T n,
92+
const T p) {
9193
T unif;
9294
T geom_sum = 0.0;
9395
int64_t num_geom = 0;
9496
T logprob = std::log1p(-p);
9597
std::uniform_real_distribution<T> dist(0.0, 1.0);
96-
auto gen_ptr = ctx.GetGenerator();
98+
auto gen_ptr = dev_ctx.GetGenerator();
9799
auto engine = gen_ptr->GetCPUEngine();
98100

99101
while (1) {
@@ -109,23 +111,23 @@ inline int64_t binomial_inversion(const Context& ctx, const T n, const T p) {
109111
}
110112

111113
template <typename T, typename Context>
112-
inline int64_t BinomialFunctor(const Context& ctx, const T n, const T p) {
114+
inline int64_t BinomialFunctor(const Context& dev_ctx, const T n, const T p) {
113115
if (n <= 0.0 || p <= 0.0) {
114116
return 0;
115117
} else if (p >= 1.0) {
116118
return static_cast<int64_t>(n);
117119
} else if (p <= 0.5) {
118120
if (n * p >= 10.0) {
119-
return btrs<T>(ctx, n, p);
121+
return btrs<T>(dev_ctx, n, p);
120122
} else {
121-
return binomial_inversion<T>(ctx, n, p);
123+
return binomial_inversion<T>(dev_ctx, n, p);
122124
}
123125
} else {
124126
T qprob = 1.0 - p;
125127
if (n * qprob >= 10.0) {
126-
return static_cast<int64_t>(n) - btrs<T>(ctx, n, qprob);
128+
return static_cast<int64_t>(n) - btrs<T>(dev_ctx, n, qprob);
127129
} else {
128-
return static_cast<int64_t>(n) - binomial_inversion<T>(ctx, n, qprob);
130+
return static_cast<int64_t>(n) - binomial_inversion<T>(dev_ctx, n, qprob);
129131
}
130132
}
131133
}

paddle/phi/kernels/funcs/broadcast_function.h

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -436,15 +436,15 @@ __global__ void VectorizedBroadcastKernel(
436436

437437
template <typename OutT, typename Functor, int Arity, int NumOuts, int VecSize>
438438
void LaunchBroadcastKernel(
439-
const KPDevice &ctx,
439+
const KPDevice &dev_ctx,
440440
const BroadcastTypeClassifier<OutT, Functor, Arity, NumOuts> &classifier,
441441
Functor func) {
442442
#ifdef PADDLE_WITH_XPU_KP
443443
int numel = classifier.numel;
444444
const int threads = 64;
445445
const int blocks = 8;
446446
int read_lens = configs[0].buf_len;
447-
auto stream = ctx.x_context()->xpu_stream;
447+
auto stream = dev_ctx.x_context()->xpu_stream;
448448
int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
449449
int tail_tid = numel % (read_lens * threads);
450450

@@ -461,8 +461,8 @@ void LaunchBroadcastKernel(
461461
#else
462462
const int &numel = classifier.numel;
463463
auto gpu_config =
464-
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
465-
auto stream = ctx.stream();
464+
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, VecSize);
465+
auto stream = dev_ctx.stream();
466466
auto threads = gpu_config.GetBlockSize();
467467
auto blocks = gpu_config.block_per_grid;
468468
int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
@@ -513,20 +513,20 @@ void LaunchBroadcastKernel(
513513

514514
template <typename OutT, typename Functor, int Arity, int NumOuts = 1>
515515
typename std::enable_if<!NeedVectorized<OutT>::value, void>::type
516-
BroadcastKernelForDifferentVecSize(const KPDevice &ctx,
516+
BroadcastKernelForDifferentVecSize(const KPDevice &dev_ctx,
517517
const std::vector<const DenseTensor *> &ins,
518518
std::vector<DenseTensor *> *outs,
519519
int axis,
520520
Functor func) {
521521
auto classifier =
522522
BroadcastTypeClassifier<OutT, Functor, Arity, NumOuts>(ins, outs, axis);
523523
LaunchBroadcastKernel<OutT, Functor, Arity, NumOuts, VecSizeS>(
524-
ctx, classifier, func);
524+
dev_ctx, classifier, func);
525525
}
526526

527527
template <typename OutT, typename Functor, int Arity, int NumOuts = 1>
528528
typename std::enable_if<NeedVectorized<OutT>::value, void>::type
529-
BroadcastKernelForDifferentVecSize(const KPDevice &ctx,
529+
BroadcastKernelForDifferentVecSize(const KPDevice &dev_ctx,
530530
const std::vector<const DenseTensor *> &ins,
531531
std::vector<DenseTensor *> *outs,
532532
int axis,
@@ -545,17 +545,17 @@ BroadcastKernelForDifferentVecSize(const KPDevice &ctx,
545545
switch (vec_size) {
546546
case VecSizeL: {
547547
LaunchBroadcastKernel<OutT, Functor, Arity, NumOuts, VecSizeL>(
548-
ctx, classifier, func);
548+
dev_ctx, classifier, func);
549549
break;
550550
}
551551
case VecSizeM: {
552552
LaunchBroadcastKernel<OutT, Functor, Arity, NumOuts, VecSizeM>(
553-
ctx, classifier, func);
553+
dev_ctx, classifier, func);
554554
break;
555555
}
556556
case VecSizeS: {
557557
LaunchBroadcastKernel<OutT, Functor, Arity, NumOuts, VecSizeS>(
558-
ctx, classifier, func);
558+
dev_ctx, classifier, func);
559559
break;
560560
}
561561
default: {
@@ -591,7 +591,7 @@ static void SliceTensor(DenseTensor *x,
591591
}
592592

593593
template <typename OutT, typename Functor, int kArity, int NumOuts = 1>
594-
void BroadcastKernelSplit(const KPDevice &ctx,
594+
void BroadcastKernelSplit(const KPDevice &dev_ctx,
595595
const std::vector<const DenseTensor *> &ins,
596596
std::vector<DenseTensor *> *outs,
597597
int axis,
@@ -728,12 +728,12 @@ void BroadcastKernelSplit(const KPDevice &ctx,
728728
}
729729

730730
BroadcastKernelForDifferentVecSize<OutT, Functor, kArity, NumOuts>(
731-
ctx, new_ins, &new_outs, axis, func);
731+
dev_ctx, new_ins, &new_outs, axis, func);
732732
}
733733
}
734734

735735
template <typename OutT, typename Functor, int kArity, int NumOuts = 1>
736-
void BroadcastKernelApply(const KPDevice &ctx,
736+
void BroadcastKernelApply(const KPDevice &dev_ctx,
737737
const std::vector<const DenseTensor *> &ins,
738738
std::vector<DenseTensor *> *outs,
739739
int axis,
@@ -748,16 +748,16 @@ void BroadcastKernelApply(const KPDevice &ctx,
748748
}
749749
if (use_int64_index_kernel) { // use_int64_index_kernel
750750
BroadcastKernelSplit<OutT, Functor, kArity, NumOuts>(
751-
ctx, ins, outs, axis, func, compute_size);
751+
dev_ctx, ins, outs, axis, func, compute_size);
752752
return;
753753
}
754754
#endif
755755
BroadcastKernelForDifferentVecSize<OutT, Functor, kArity, NumOuts>(
756-
ctx, ins, outs, axis, func);
756+
dev_ctx, ins, outs, axis, func);
757757
}
758758

759759
template <typename OutT, typename Functor, int NumOuts = 1>
760-
void BroadcastKernel(const KPDevice &ctx,
760+
void BroadcastKernel(const KPDevice &dev_ctx,
761761
const std::vector<const DenseTensor *> &ins,
762762
std::vector<DenseTensor *> *outs,
763763
Functor func,
@@ -805,7 +805,7 @@ void BroadcastKernel(const KPDevice &ctx,
805805
"%d-th output tensor`s shape is not.",
806806
i));
807807
}
808-
ctx.template Alloc<OutT>((*outs)[i]);
808+
dev_ctx.template Alloc<OutT>((*outs)[i]);
809809
}
810810
if ((*outs)[0]->numel() == 0) {
811811
return;
@@ -823,7 +823,7 @@ void BroadcastKernel(const KPDevice &ctx,
823823
}
824824
axis = axis == -1 ? max_rank - min_rank : axis;
825825
BroadcastKernelApply<OutT, Functor, kArity, NumOuts>(
826-
ctx, ins, outs, axis, func);
826+
dev_ctx, ins, outs, axis, func);
827827
}
828828

829829
template <typename Functor, typename T, typename OutType = T>

0 commit comments

Comments
 (0)