Approximate Functions
import torch
import matplotlib.pyplot as plt
class Function:
    """Quadratic model ``f(x) = w * x**2 + b``.

    Args:
        w: Coefficient of ``x**2``. When omitted (``None``), a random
            ``torch.rand(1, requires_grad=True)`` tensor is created so it can be
            fitted by gradient descent.
        b: Constant offset. When omitted (``None``), a random
            ``torch.rand(1, requires_grad=True)`` tensor is created.
    """

    def __init__(self, w=None, b=None):
        # Compare against None explicitly: ``w or ...`` calls ``bool(tensor)``,
        # which treats a zero-valued tensor (e.g. ``torch.tensor([0.0])``) as
        # "missing" and raises for tensors holding more than one element.
        self.w = w if w is not None else torch.rand(1, requires_grad=True)
        self.b = b if b is not None else torch.rand(1, requires_grad=True)

    def __call__(self, x):
        """Evaluate ``w * x**2 + b`` element-wise (with broadcasting) on *x*."""
        return self.w * x**2 + self.b
class Approximator:
    """Fit a callable model's tensors to data with plain gradient descent.

    The trainable tensors are taken from ``function.parameters()`` when the
    model provides such a method (e.g. a ``torch.nn.Module``); otherwise from
    its ``w`` and ``b`` attributes. If the model has a ``criterion(y, y_hat)``
    attribute it is used as the loss; otherwise the mean squared error is used.

    Args:
        function: The model to fit; must be callable on the input tensor.
        learning_rate: Step size of each gradient-descent update.
    """

    def __init__(self, function, learning_rate=0.02):
        self.learning_rate = learning_rate
        self.function = function
        # Model-supplied loss wins; fall back to mean squared error.
        self.criterion = getattr(
            self.function,
            'criterion',
            lambda y, y_hat: ((y - y_hat) ** 2).mean(),
        )

    def _parameters(self):
        """Return the list of tensors that gradient descent should update."""
        parameters = getattr(self.function, 'parameters', None)
        if callable(parameters):
            return list(parameters())
        return [self.function.w, self.function.b]

    def approximate(self, x, y, epochs=1000):
        """Run *epochs* full-batch gradient-descent steps fitting ``function(x)`` to *y*.

        Args:
            x: Input tensor passed to the model.
            y: Target tensor compared with the model output via ``criterion``.
            epochs: Number of update steps.
        """
        parameters = self._parameters()
        for _ in range(epochs):
            y_hat = self.function(x)
            loss = self.criterion(y, y_hat)
            loss.backward()
            # Update in place outside the autograd graph, then clear the
            # accumulated gradient so the next step starts from zero.
            with torch.no_grad():
                for p in parameters:
                    if p.grad is None:
                        # Not reached by backward (e.g. requires_grad=False): leave as is.
                        continue
                    p -= self.learning_rate * p.grad
                    p.grad.zero_()

    def __call__(self, x):
        """Evaluate the wrapped (fitted) model on *x*."""
        return self.function(x)
###################
# Demo: fit a quadratic to noisy samples of 4 * x**2, then plot the fit.
torch.manual_seed(0)


def generate_data(n, w=4.0, b=0.0, noise=0.1):
    """Return ``(x, y)`` with *n* noisy samples of ``y = w * x**2 + b``.

    ``x`` has shape ``(n, 1)`` and is drawn with ``torch.rand`` (uniform on
    ``[0, 1)``); Gaussian noise scaled by *noise* is added to ``y``.
    """
    x = torch.rand(n, 1)
    ground_truth = Function(torch.tensor([w]), torch.tensor([b]))
    # The target is fixed data, not something to differentiate through.
    with torch.no_grad():
        y = ground_truth(x) + noise * torch.randn(n, 1)
    return x, y


x, y = generate_data(100)
function = Function()
approx = Approximator(function)
approx.approximate(x, y)
plt.scatter(x.numpy(), y.numpy())

# validate: evaluate on [0, 2) to see how the fit extrapolates beyond the
# [0, 1) training range.
x, _ = generate_data(100)
x = x * 2
with torch.no_grad():
    prediction = approx(x)
plt.scatter(x.numpy(), prediction.numpy(), c='r')
plt.show()
No comments:
Post a Comment