Estaba revisando el documento BERT que usa GELU (Unidad lineal de error gaussiano) que establece la ecuación como
que a su vez se aproxima a
¿Podría simplificar la ecuación y explicar cómo se ha aproximado?
fuente
Estaba revisando el documento BERT que usa GELU (Unidad lineal de error gaussiano) que establece la ecuación como
que a su vez se aproxima a
¿Podría simplificar la ecuación y explicar cómo se ha aproximado?
Podemos expandir la distribución acumulativa de , es decir, , de la siguiente manera:
Tenga en cuenta que esta es una definición , no una ecuación (o una relación). Los autores han proporcionado algunas justificaciones para esta propuesta, por ejemplo, una analogía estocástica , aunque matemáticamente, esta es solo una definición.
Aquí está la trama de GELU:
Para este tipo de aproximaciones numéricas, la idea clave es encontrar una función similar (principalmente basada en la experiencia), parametrizarla y luego ajustarla a un conjunto de puntos de la función original.
Sabiendo que está muy cerca de
y la primera derivada de coincide con la de en , que es , procedemos a ajustar
(o con más términos) a un conjunto de puntos .
He ajustado esta función a 20 muestras entre ( usando este sitio ), y aquí están los coeficientes:
Al establecer , se estimó en . Con más muestras de un rango más amplio (ese sitio solo permitió 20), el coeficiente estará más cerca del del papel . Finalmente llegamos
con error cuadrático medio para .
Tenga en cuenta que si no utilizamos la relación entre las primeras derivadas, el término se habría incluido en los parámetros de la siguiente manera
que es menos bella (menos analítica, más numérica).
Como lo sugiere @BookYourLuck , podemos utilizar la paridad de funciones para restringir el espacio de los polinomios en los que buscamos. Es decir, dado que es una función extraña, es decir, , y también es una función extraña, función polinómica dentro también debe ser impar (solo debe tener poderes impares de ) para tener
Anteriormente, tuvimos la suerte de terminar con (casi) cero los coeficientes de potencias pares y , sin embargo, en general, esto podría dar lugar a aproximaciones de baja calidad que, por ejemplo, tienen un término como que se cancela mediante términos adicionales (pares o impares) en lugar de simplemente optar por .
Una relación similar se mantiene entre y (sigmoid), que se propone en el documento como otra aproximación, con error cuadrático medio para .
Aquí hay un código de Python para generar puntos de datos, ajustar las funciones y calcular los errores cuadrados medios:
import math
import numpy as np
import scipy.optimize as optimize
def tahn(xs, a):
return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]
def sigmoid(xs, a):
return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]
print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
# .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])
# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))
sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))
# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()
# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()
print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)
if print_points == 1:
print(len(xs))
for x, erf in zip(xs, erfs):
print(x, erf)
Salida:
Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05
fuente