19 changes: 0 additions & 19 deletions TROUBLESHOOTING.md
@@ -334,25 +334,6 @@ only be enabled for debugging.
by one. This is useful to bypass the long compilation time, but the overall step time will be much slower and memory usage will be higher
since all compiler optimizations will be skipped.

* ```XLA_USE_BF16```: If set to 1, converts all the _PyTorch_ _Float_ values into _BFloat16_
when sending them to the _TPU_ device. Note that when using `XLA_USE_BF16=1`, tensor arithmetic
is done in reduced precision, so tensors will not be accurate if accumulated over time.
For example:

```
# In reduced bfloat16 precision
>>> torch.tensor(4096, dtype=torch.bfloat16) + torch.tensor(1, dtype=torch.bfloat16)
tensor(4096., dtype=torch.bfloat16)
# Whereas in full float32 precision
>>> torch.tensor(4096) + torch.tensor(1)
tensor(4097)
```
To get accurate metrics such as the average loss value over many steps, use manual mixed
precision where metrics stay in FP32 (see the sketch below).
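
A minimal sketch of that manual mixed-precision pattern, not part of this diff; the loop and
the per-step loss value are placeholders standing in for a real training step:

```
# Assumes XLA_USE_BF16=1, so step losses arrive as bfloat16 tensors.
# Keep the metric accumulator in float32 so the sum stays accurate over many steps.
import torch

num_steps = 3
running_loss = torch.zeros((), dtype=torch.float32)     # FP32 accumulator for metrics
for _ in range(num_steps):
    loss = torch.tensor(4096.0, dtype=torch.bfloat16)   # stand-in for a bf16 step loss
    running_loss += loss.detach().float()                # accumulate in full precision
average_loss = running_loss / num_steps
```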

* ```XLA_USE_F16```: If set to 1, converts all the _PyTorch_ _Float_ values into _Float16_
(the _PyTorch_ _Half_ type) when sending them to devices that support it.

* ```TF_CPP_LOG_THREAD_ID```: If set to 1, the TF logs will show the thread ID,
which helps with debugging multithreaded processes.

14 changes: 14 additions & 0 deletions torch_xla/csrc/dtype.cpp
@@ -10,6 +10,9 @@ namespace {
bool ShouldUseBF16() {
bool use_bf16 = runtime::sys_util::GetEnvBool("XLA_USE_BF16", false);
if (use_bf16) {
std::cout
<< "XLA_USE_BF16 will be deprecated after the 2.4 release, please "
"convert your model to bf16 directly\n";
TF_LOG(INFO) << "Using BF16 data type for floating point values";
}
return use_bf16;
@@ -18,6 +21,9 @@ bool ShouldUseBF16() {
bool ShouldUseF16() {
bool use_fp16 = runtime::sys_util::GetEnvBool("XLA_USE_FP16", false);
if (use_fp16) {
std::cout
<< "XLA_USE_FP16 will be deprecated after the 2.4 release, please "
"convert your model to fp16 directly\n";
TF_LOG(INFO) << "Using F16 data type for floating point values";
}
return use_fp16;
@@ -27,6 +33,9 @@ bool ShouldDowncastToBF16() {
bool downcast_bf16 =
runtime::sys_util::GetEnvBool("XLA_DOWNCAST_BF16", false);
if (downcast_bf16) {
std::cout
<< "XLA_DOWNCAST_BF16 will be deprecated after the 2.4 release, please "
"downcast your model directly\n";
TF_LOG(INFO) << "Downcasting floating point values, F64->F32, F32->BF16";
}
return downcast_bf16;
@@ -36,6 +45,9 @@ bool ShouldDowncastToF16() {
bool downcast_fp16 =
runtime::sys_util::GetEnvBool("XLA_DOWNCAST_FP16", false);
if (downcast_fp16) {
std::cout
<< "XLA_DOWNCAST_FP16 will be deprecated after the 2.4 release, please "
"downcast your model directly\n";
TF_LOG(INFO) << "Downcasting floating point values, F64->F32, F32->FP16";
}
return downcast_fp16;
@@ -45,6 +57,8 @@ bool ShouldUse32BitLong() {
bool use_32bit_long =
runtime::sys_util::GetEnvBool("XLA_USE_32BIT_LONG", false);
if (use_32bit_long) {
std::cout
<< "XLA_USE_32BIT_LONG will be deprecated after the 2.4 release\n";
TF_LOG(INFO) << "Using 32bit integers for kLong values";
}
return use_32bit_long;
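
The new warnings tell users to cast their model directly instead of relying on the environment
variable. A minimal sketch of that migration, under the assumption of a simple placeholder module
(the `nn.Linear` and input shapes are illustrative, not from this PR):

```
# Instead of exporting XLA_USE_BF16=1, cast the model and its inputs to bf16 explicitly.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(8, 8).to(torch.bfloat16).to(device)        # placeholder module, cast explicitly
inputs = torch.randn(2, 8, dtype=torch.bfloat16, device=device)
outputs = model(inputs)                                       # runs in bf16 without XLA_USE_BF16
```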