Durante 50 años, desde la época de Kernighan, Ritchie y su primera edición del libro en lenguaje C, se sabía que un tipo “flotante” de precisión simple tiene un tamaño de 32 bits y un tipo de precisión doble tiene 64 bits. También existía un tipo “doble largo” de 80 bits con precisión extendida, y todos estos tipos cubrían casi todas las necesidades del procesamiento de datos de punto flotante. Sin embargo, durante los últimos años, la llegada de grandes modelos de redes neuronales obligó a los desarrolladores a pasar a otra parte del espectro y reducir los tipos de punto flotante tanto como fuera posible.
Honestamente, me sorprendió cuando descubrí que existe el formato de punto flotante de 4 bits. ¿Cómo diablos puede ser posible? La mejor manera de saberlo es probarlo por nuestra cuenta. En este artículo, descubriremos los formatos de punto flotante más populares, crearemos una red neuronal simple y veremos cómo funciona.
Empecemos.
Un punto flotante “estándar” de 32 bits
Antes de entrar en formatos “extremos”, recordemos uno estándar. Un IEEE 754 El estándar para aritmética de punto flotante fue establecido en 1985 por el Instituto de Ingenieros Eléctricos y Electrónicos (IEEE). Un número típico en un tipo de 32 flotantes se ve así:
Aquí, el primer bit es un signo, los siguientes 8 bits representan un exponente y los últimos bits representan la mantisa. El valor final se calcula mediante la fórmula:
Esta sencilla función auxiliar nos permite imprimir un valor de punto flotante en forma binaria:
import structdef print_float32(val: float):
""" Print Float32 in a binary form """
m = struct.unpack('I', struct.pack('f', val))[0]
return format(m, 'b').zfill(32)
print_float32(0.15625)
# > 00111110001000000000000000000000
También creemos otro asistente para la conversión hacia atrás, que será útil más adelante:
def ieee_754_conversion(sign, exponent_raw, mantissa, exp_len=8, mant_len=23):
""" Convert binary data into the floating point value """
sign_mult = -1 if sign == 1 else 1
exponent = exponent_raw - (2 ** (exp_len - 1) - 1)
mant_mult = 1
for b in range(mant_len - 1, -1, -1):
if mantissa & (2 **…