설치를 할 때 먼저 jaxlib을 설치하고 그 버전에 맞는 jax를 설치합니다.

! pip install jaxlib
! pip install jax==0.3.15
from jax import numpy as jnp
from jax import grad
import numpy as np
from plotnine import *
import pandas as pd
from tqdm.notebook import tqdm

간단한 모형을 만들어 봅니다. X와 y가 2차함수 형태로 결합된 경우를 생각해 봅니다.

n = 100
X = np.random.uniform(0, 3, size=n)
y = 3 * np.power(X, 2) + np.random.normal(10, 3, size=n)

data = pd.DataFrame(zip(X, y), columns=["X", "y"])
(
    ggplot(data)
    + aes("X", "y")
    + geom_point()
)
<ggplot: (8762106993550)>

선형 모델을 먼저 생각해 봅니다.

w = {"a": 0., "b": 0.}

# set model
def model(w, X):
    return w["a"] * X + w["b"]

# set loss
def loss(w, model, X, y):
    return jnp.power(model(w, X) - y, 2).sum()

# grad loss
dloss = grad(loss)

이제 경사하강법을 활용하여 w를 찾아봅니다.

경사하강법은 말그대로 경사를 구해서 낮은 쪽으로 이동하게 하는 것입니다.

기본적인 아이디어는 예측치와 관측값의 차이를 합치는 손실함수(loss function)을 구합니다. 그리고 파라미터를 손실이 줄어드는 방향(경사, 미분해서 보통 구합니다)으로 조금씩 옮겨가면서 최적의 값을 찾아 한발 한발 나아가는 방식입니다.

수식으로 간단하게 표기해보자면 $$L(\theta) = \sum_i ( f(x_i; \theta) - y)^2$$ 로 정의하고 이 $L$을 $\theta$로 미분해 해당 미분값(경사)를 이용해서 낮추는 방향으로 파라미터 $\theta$를 바꿔가면서 찾아가는 방식입니다.

rate = 0.0001

losses = []
ws = []
for i in tqdm(range(2000)):
    l = loss(w, model, X, y)
    ws.append(w.copy())
    losses.append(l)
    if i % 100 == 0:
        print(i, "loss: ", l)
    dw = dloss(w, model, X, y)
    for key in w.keys():
        w[key] -= dw[key]*rate
0 loss:  51256.832
100 loss:  1503.3762
200 loss:  1503.3103
300 loss:  1503.2767
400 loss:  1503.259
500 loss:  1503.25
600 loss:  1503.2452
700 loss:  1503.2428
800 loss:  1503.2415
900 loss:  1503.2407
1000 loss:  1503.2406
1100 loss:  1503.2402
1200 loss:  1503.2402
1300 loss:  1503.24
1400 loss:  1503.24
1500 loss:  1503.24
1600 loss:  1503.24
1700 loss:  1503.24
1800 loss:  1503.24
1900 loss:  1503.24
result_df = pd.DataFrame(zip(X, np.array(model(w, X))), columns=["X", "f"])
(
    ggplot(data=data) +
    aes("X", "y") +
    geom_point() +
    geom_smooth(method="lm") +
    geom_line(data=result_df, mapping=aes("X", "f"),  color="#ff1234")
    
)
<ggplot: (8762106674460)>
dfs = [pd.DataFrame(zip(map(int, np.ones_like(X)*i), X, np.array(model(ws[i], X))), columns=["i", "X", "f"]) for i in range(0, 50, 5)]
df = pd.concat(dfs)

처음에는 많이 차이나지만 점점 해석적으로 계산한 선형 회귀 값과 유사해지는 것을 볼 수 있습니다.

이 경사하강법의 장점은 손실함수를 정의 할 수만 있다면 적용할 수 있어 유연하게 많은 곳에 적용할 수 있습니다.

p = (
    ggplot(data=df) +
    aes(x="X", y="f") +
    geom_point(data=data, mapping=aes("X", "y")) +
    geom_smooth(data=data, method="lm", mapping=aes("X", "y"), color="yellow") +
    geom_line(color="red", size=1) +
    facet_wrap("i")
)
p
<ggplot: (8762106626557)>

하나의 그래프에 겹쳐서 표현하면 아래와 같은 그래프가 됩니다.

p = (ggplot() +
    geom_point(data=data, mapping=aes("X", "y")) +
    geom_smooth(data=data, method="lm", mapping=aes("X", "y"), color="yellow")
)

for df in dfs:
    p += geom_line(data=df, mapping=aes(x="X", y="f", color="i"))


p
<ggplot: (8762106770219)>
Bad pipe message: %s [b'\xecz\x02\x84\x82m\x86XAW\xd5O\x11\xd9T\xd2\x8eH *&\xe0\xb7\xc4y\x1bFH\x86l:\x0f\r\r[\x91\xf9Wgs\xd0+/\xcf\xfc\xc3Z<.\xd0\xb6\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00']
Bad pipe message: %s [b'#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03']
Bad pipe message: %s [b'\x08\x08\x08\t\x08\n\x08', b'\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06']
Bad pipe message: %s [b"'\xd0\x01]Pb\xa7r\x8eY\xce\xeb\xd4\xa2v\xb8\x7f\xfe\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0#\xc0'\x00g\x00@\xc0\n\xc0\x14\x009\x008\xc0\t\xc0\x13\x003\x002\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc0P\x00=\x00<\x005\x00/\x00\x9a\x00\x99\xc0\x07\xc0\x11\x00\x96\x00\x05\x00\xff"]
Bad pipe message: %s [b'']
Bad pipe message: %s [b'']
Bad pipe message: %s [b'\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 \x14\x9al(\xf3PpR\x01<\xeao\x115\x0b"f\xbb\x1f\xb1t\xbb']
Bad pipe message: %s [b'\xd2\xc1\x1e+D\xa7f#\x10\xc6\x06\xe0 \x9e\x85\x16?\xe7\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0', b"\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V\xc0R\xc0$\xc0(\x00k\x00j\xc0s\xc0w\x00\xc4\x00\xc3\xc0#\xc0'\x00g\x00@\xc0r\xc0v\x00\xbe\x00\xbd\xc0\n\xc0\x14\x009\x008\x00\x88\x00\x87\xc0\t\xc0\x13\x003\x002\x00\x9a\x00\x99\x00E\x00D\xc0\x07\xc0\x11\xc0\x08\xc0\x12\x00\x16\x00\x13\x00\x9d\xc0\xa1\xc0\x9d\xc0Q\x00\x9c\xc0\xa0\xc0\x9c\xc0P\x00=\x00\xc0\x00<\x00\xba\x005\x00\x84\x00/\x00\x96\x00A\x00\x05\x00\n\x00\xff\x01\x00\x00j\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#"]
Bad pipe message: %s [b'\xe7/\xd0\x9b\xc4\x98!']
Bad pipe message: %s [b'E\xa8\x82\xba_2\x87Y\x98\x00\x00']
Bad pipe message: %s [b'\x14\xc0\n\x009\x008\x007\x006\xc0\x0f\xc0\x05\x005\xc0\x13\xc0\t\x003\x002\x001\x000\xc0\x0e\xc0\x04\x00/\x00\x9a\x00\x99\x00\x98\x00\x97\x00\x96\x00\x07\xc0\x11\xc0\x07\xc0\x0c\xc0\x02\x00\x05\x00\x04\x00\xff']
Bad pipe message: %s [b"\x04h'\x1e\x85\x8c\xcdb\x84m)\xfd^\x02\xbb\x80T\xd6\x00\x00\xa2\xc0\x14\xc0\n\x009\x008\x007\x006\x00\x88\x00\x87\x00\x86\x00\x85\xc0\x19\x00:\x00\x89\xc0\x0f\xc0\x05\x005\x00\x84\xc0\x13\xc0\t\x003\x002\x001\x000\x00\x9a\x00\x99\x00\x98\x00\x97\x00E\x00D\x00C\x00B\xc0\x18\x004\x00\x9b\x00F\xc0\x0e\xc0\x04\x00/\x00\x96\x00A\x00\x07\xc0\x11\xc0\x07\xc0\x16\x00\x18\xc0\x0c\xc0\x02\x00\x05\x00\x04\xc0\x12\xc0\x08\x00\x16\x00\x13\x00\x10\x00\r\xc0\x17\x00\x1b\xc0\r\xc0\x03\x00\n\x00\x15\x00\x12\x00\x0f\x00\x0c\x00\x1a\x00\t\x00\x14\x00\x11\x00\x19\x00\x08\x00\x06\x00\x17\x00\x03\xc0\x10\xc0\x06\xc0\x15\xc0\x0b\xc0\x01\x00\x02\x00\x01\x00\xff\x02\x01\x00\x00C\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x1c\x00\x1a\x00\x17\x00\x19\x00\x1c\x00"]
Bad pipe message: %s [b'\x18\x00\x1a\x00\x16\x00\x0e\x00\r\x00\x0b\x00\x0c\x00\t\x00\n\x00#\x00\x00\x00\x0f\x00\x01\x01']
Bad pipe message: %s [b'\x01\xb6\xd1\x07\x00\xcbO\x95W\xec\xd6\x96l\xf1\xc8\xc1\xfaZ\x00\x00\xa2\xc0\x14\xc0\n\x009\x008\x007\x006\x00\x88\x00\x87\x00\x86\x00\x85\xc0\x19\x00:\x00\x89\xc0\x0f\xc0\x05\x005\x00\x84\xc0\x13\xc0\t\x003\x002\x001\x000\x00\x9a\x00\x99\x00\x98\x00\x97\x00E\x00D\x00C\x00B\xc0\x18\x004\x00\x9b\x00']
Bad pipe message: %s [b'\x0e\xc0\x04\x00/\x00\x96\x00A\x00\x07\xc0\x11\xc0\x07\xc0\x16\x00\x18\xc0\x0c\xc0\x02\x00\x05\x00\x04\xc0\x12\xc0\x08\x00\x16\x00\x13\x00\x10\x00\r\xc0\x17\x00\x1b\xc0\r\xc0\x03\x00\n\x00\x15\x00\x12\x00\x0f\x00\x0c\x00\x1a\x00\t\x00\x14\x00\x11\x00\x19\x00\x08']
Bad pipe message: %s [b'\xfe"\x98\xab\xdd\xec~\x10\'OD\x8eAI\x0b\xd0;\xf5\x00\x00\xa2\xc0\x14\xc0\n\x009\x008\x007\x006\x00\x88\x00\x87\x00\x86\x00\x85\xc0\x19\x00:\x00\x89\xc0\x0f\xc0\x05\x005\x00\x84\xc0\x13\xc0\t\x003\x002\x001\x000\x00\x9a\x00\x99\x00\x98\x00\x97\x00E\x00D\x00C\x00B\xc0\x18\x004\x00\x9b\x00F\xc0\x0e\xc0\x04\x00/\x00\x96\x00A\x00\x07\xc0\x11']
Bad pipe message: %s [b"Q\x95\\\x9f\x1a\x9e\x13\xb9\xe6h~\x9c\xa5\x80a\x0b\x11j\x00\x00\x86\xc00\xc0,\xc0(\xc0$\xc0\x14\xc0\n\x00\xa5\x00\xa3\x00\xa1\x00\x9f\x00k\x00j\x00i\x00h\x009\x008\x007\x006\xc02\xc0.\xc0*\xc0&\xc0\x0f\xc0\x05\x00\x9d\x00=\x005\xc0/\xc0+\xc0'\xc0#\xc0\x13\xc0\t\x00\xa4\x00\xa2\x00\xa0\x00\x9e\x00g\x00@\x00?\x00>\x003\x002\x001\x000\xc01\xc0-\xc0)\xc0%\xc0\x0e\xc0\x04\x00\x9c\x00<\x00/\x00\x9a\x00\x99\x00\x98\x00\x97\x00\x96\x00\x07\xc0\x11\xc0\x07\xc0\x0c\xc0\x02\x00\x05\x00\x04\x00\xff\x02\x01\x00\x00g\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x1c\x00\x1a\x00\x17\x00\x19\x00\x1c\x00\x1b\x00\x18\x00\x1a\x00\x16\x00\x0e\x00\r\x00\x0b\x00\x0c\x00\t\x00\n\x00#\x00\x00\x00\r\x00 \x00\x1e\x06\x01\x06\x02\x06\x03\x05\x01\x05\x02\x05\x03", b'\x04\x02\x04', b'\x01\x03', b'\x03', b'\x02', b'\x03']