"""Fit a cubic polynomial to sin(x) on [-pi, pi] with plain NumPy gradient descent.

The model ``a + b*x + c*x**2 + d*x**3`` is trained by full-batch gradient
descent on the summed squared error, using hand-derived gradients. The fitted
coefficients are printed, and the target curve is plotted against the fit.
"""
import math

import numpy as np


def fit_cubic(x, y, learning_rate=1e-6, steps=2000, init=None, log_every=100):
    """Fit ``y ~ a + b*x + c*x**2 + d*x**3`` by full-batch gradient descent.

    The loss is the *summed* (not mean) squared error. A stable
    ``learning_rate`` therefore has to shrink as ``len(x)`` grows.

    Args:
        x: 1-D sequence/array of sample points.
        y: 1-D sequence/array of targets, same length as ``x``.
        learning_rate: Step size applied to each gradient update.
        steps: Number of gradient-descent iterations. ``0`` returns the
            starting coefficients unchanged.
        init: Optional starting coefficients ``(a, b, c, d)``. When omitted
            they are drawn from ``np.random.randn`` (the global NumPy RNG).
            The caller's object is copied, never mutated.
        log_every: Print ``step loss`` whenever ``step % log_every ==
            log_every - 1`` (steps 99, 199, ... for the default 100). The
            loss printed is the one computed before that step's update.
            ``None`` or ``0`` disables logging.

    Returns:
        Tuple of Python floats ``(a, b, c, d)``.

    Raises:
        ValueError: If ``x`` and ``y`` are not 1-D with equal shapes, or
            ``init`` does not hold exactly four values.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.ndim != 1 or x.shape != y.shape:
        raise ValueError(
            f'x and y must be 1-D arrays of equal length, got shapes {x.shape} and {y.shape}'
        )

    # Design matrix with columns [1, x, x**2, x**3]. It is built once, so the
    # powers of x are not re-evaluated on every iteration.
    design = np.vander(x, 4, increasing=True)

    if init is None:
        coeffs = np.random.randn(4)
    else:
        # np.array copies, so the in-place update below cannot touch the caller's data.
        coeffs = np.array(init, dtype=float)
        if coeffs.shape != (4,):
            raise ValueError(f'init must contain exactly 4 coefficients, got shape {coeffs.shape}')

    for t in range(steps):
        residual = design @ coeffs - y
        if log_every and t % log_every == log_every - 1:
            print(t, np.square(residual).sum())
        # Gradient of sum(residual**2) w.r.t. the coefficients is X^T (2 * residual).
        # This equals the per-coefficient sums grad_k = sum(2 * residual * x**k).
        coeffs -= learning_rate * (design.T @ (2.0 * residual))

    return tuple(float(v) for v in coeffs)


def plot_fit(x, y, coeffs):
    """Plot the target ``y`` and the cubic defined by ``coeffs`` against ``x``.

    ``coeffs`` is ``(a, b, c, d)`` in increasing-power order, as returned by
    ``fit_cubic``. Blocks in ``plt.show()`` until the window is closed
    (under matplotlib's default non-interactive script mode).
    """
    # Local import: fitting (and importing this module) must not require matplotlib.
    import matplotlib.pyplot as plt

    x = np.asarray(x, dtype=float)
    # Evaluate with the final coefficients, not a prediction cached before the last update.
    y_fit = np.vander(x, len(coeffs), increasing=True) @ np.asarray(coeffs, dtype=float)

    fig, ax = plt.subplots()
    ax.plot(x, y, label='sin(x)')
    ax.plot(x, y_fit, label='cubic fit')
    ax.legend()
    plt.show()


def main():
    """Fit a cubic to sin(x) on [-pi, pi], print it, and plot the result."""
    x = np.linspace(-math.pi, math.pi, 2000)
    y = np.sin(x)

    a, b, c, d = fit_cubic(x, y)
    print(f'Result: y = {a} + {b} x + {c} x^2 + {d} x^3')
    plot_fit(x, y, (a, b, c, d))


if __name__ == '__main__':
    main()