FP16(半精度):在 FP16 中,浮点数使用 16 位表示。它由 1 个符号位、5 位指数和 10 位分数(尾数)组成。此格式提供了更高的精度来表示其范围内的小数值。
BF16 (BFloat16):BF16 也使用 16 位,但分布不同。它有 1 个符号位、8 位指数、7 位尾数。这种格式牺牲了小数部分的一些精度以适应更广泛的指数。
FP16 由于其 10 位尾数而具有较小的范围,但在该范围内精度较高。
BF16 由于其 8 位指数和 7 位尾数而具有较宽的范围,但小数值的精度较低。
下面通过3个案例来说明FP16和BF16的区别。使用TensorFlow来做测试和代码共享在底层:
原始值:0.0001 — 两种方法都可以表示
FP16: 0.00010001659393 (二进制:0|00001|1010001110,十六进制:068E) — 10 个尾数和 5 个指数
BF16: 0.00010013580322 (二进制: 0|01110001|1010010, 十六进制: 38D2) — 7 个尾数和 8 个指数
正如您所看到的,它们具有不同的指数和尾数,因此能够以不同的方式表示。 但是我们可以看到FP16更准确地表达了它,并且值更接近。
原始值:1e-08 (0.00000001)
FP16:0.00000000000000(二进制:0|00000|0000000000,十六进制:0000)
BF16:0.00000001001172 (二进制:0|01100100| 0101100,十六进制:322C)
这是一个非常有趣的案例。 FP16 失败 并使结果为 0,但 BF16 能够用特殊格式表示它。
原始值:100000.00001
FP16:inf(二进制:0|11111|0000000000,十六进制:7C00)
BF16:99840.00000000000000(二进制:0|10001111 |1000011,十六进制:47C3 )
在上述情况下,FP16 失败,因为所有指数位都已满并且不足以表示该值。然而 BF16 有效
FP16常用于深度学习训练和推理,特别是对于需要高精度表示有限范围内的小分数值的任务。
BF16 在专为机器学习任务设计的硬件架构中变得越来越流行,这些任务受益于更广泛的可表示值,即使是以牺牲小数部分的一些精度为代价。当处理大梯度或当大范围内的数值稳定性比小值的精度更重要时,它特别有用。
FP16 为较小范围内的小数值提供了更高的精度,使其适合需要精确表示小数的任务。另一方面,BF16 以牺牲一定精度为代价提供了更广泛的范围,这使得它有利于涉及更广泛值范围或在大范围内的数值稳定性至关重要的任务。 FP16 和 BF16 之间的选择取决于手头机器学习任务的具体要求。
由于上述原因,在进行 Stable Diffusion XL (SDXL) 训练时,FP16 和 BF16 需要的学习率略有不同,我发现 BF16 效果更好。
用于生成上述示例的代码
import tensorflow as tf import struct def float_to_binary(f): return ''.join(f'{b:08b}' for b in struct.pack('>f', f)) def display_fp16(value): fp16 = tf.cast(tf.constant(value, dtype=tf.float32), tf.float16) fp32 = tf.cast(fp16, tf.float32) binary = format(int.from_bytes(fp16.numpy().tobytes(), 'big'), '016b') sign = binary[0] exponent = binary[1:6] fraction = binary[6:] return f"FP16: {fp32.numpy():14.14f} (Binary: {sign}|{exponent}|{fraction}, Hex: {fp16.numpy().view('uint16'):04X})" def display_bf16(value): bf16 = tf.cast(tf.constant(value, dtype=tf.float32), tf.bfloat16) bf32 = tf.cast(bf16, tf.float32) binary = format(int.from_bytes(bf16.numpy().tobytes(), 'big'), '016b') sign = binary[0] exponent = binary[1:9] fraction = binary[9:] return f"BF16: {bf32.numpy():14.14f} (Binary: {sign}|{exponent}|{fraction}, Hex: {bf16.numpy().view('uint16'):04X})" values = [0.0001, 0.00000001, 100000.00001] for value in values: print(f"\nOriginal value: {value}") print(display_fp16(value)) print(display_bf16(value))
免责声明: 提供的所有资源部分来自互联网,如果有侵犯您的版权或其他权益,请说明详细缘由并提供版权或权益证明然后发到邮箱:[email protected] 我们会第一时间内为您处理。
Copyright© 2022 湘ICP备2022001581号-3