Approximate Functions
import torch


class Function:
    """Quadratic model ``w * x**2 + b`` with tensor coefficients.

    Args:
        w: Coefficient of ``x**2``. Defaults to a random 1-element tensor in
            [0, 1) that requires gradients.
        b: Constant offset. Defaults to a random 1-element tensor in [0, 1)
            that requires gradients.
    """

    def __init__(self, w=None, b=None):
        # Compare against None explicitly: ``w or ...`` evaluates the tensor's
        # truth value, which silently replaced a zero coefficient (e.g. b=0)
        # with a random one and raises for multi-element tensors.
        self.w = torch.rand(1, requires_grad=True) if w is None else w
        self.b = torch.rand(1, requires_grad=True) if b is None else b

    def __call__(self, x):
        """Evaluate ``w * x**2 + b`` element-wise (with broadcasting)."""
        return self.w * x**2 + self.b


class Approximator:
    """Fit a callable model's trainable tensors to data by gradient descent.

    Trainable parameters are taken from ``function.parameters()`` when the
    model provides it (e.g. a ``torch.nn.Module``); otherwise they are the
    tensor attributes of ``function`` with ``requires_grad=True`` (``w`` and
    ``b`` for :class:`Function`).

    Args:
        function: Callable model mapping inputs to predictions. If it has a
            non-None ``criterion(y, y_hat)`` attribute, that is used as the
            loss; otherwise mean squared error is used.
        learning_rate: Step size for plain (momentum-free) SGD.
    """

    def __init__(self, function, learning_rate=0.02):
        self.learning_rate = learning_rate
        self.function = function
        criterion = getattr(function, "criterion", None)
        self.criterion = criterion if criterion is not None else self._mse

    @staticmethod
    def _mse(y, y_hat):
        """Mean squared error between targets and predictions."""
        return ((y - y_hat) ** 2).mean()

    def _parameters(self):
        """Return the model's tensors that require gradients."""
        parameters = getattr(self.function, "parameters", None)
        if callable(parameters):
            return [p for p in parameters() if p.requires_grad]
        return [
            value
            for value in vars(self.function).values()
            if isinstance(value, torch.Tensor) and value.requires_grad
        ]

    def approximate(self, x, y, iterations=1000):
        """Run ``iterations`` full-batch gradient-descent steps on ``(x, y)``.

        Args:
            x: Input tensor passed to the model.
            y: Target tensor, compared against the model output by the criterion.
            iterations: Number of update steps.

        Raises:
            ValueError: If the model exposes no tensors that require gradients.
        """
        parameters = self._parameters()
        if not parameters:
            raise ValueError("function has no tensors with requires_grad=True to fit")
        # Momentum-free SGD applies exactly p -= learning_rate * p.grad, the
        # update previously hand-written for w and b; the optimizer is built
        # here so later changes to self.learning_rate take effect.
        optimizer = torch.optim.SGD(parameters, lr=self.learning_rate)
        for _ in range(iterations):
            optimizer.zero_grad()
            loss = self.criterion(y, self.function(x))
            loss.backward()
            optimizer.step()

    def __call__(self, x):
        """Evaluate the (fitted) model on ``x``."""
        return self.function(x)


def generate_data(n, w=4.0, b=0.0, noise=0.1):
    """Sample ``n`` noisy points from ``w * x**2 + b`` with ``x ~ U[0, 1)``.

    Args:
        n: Number of samples.
        w: Quadratic coefficient of the generating function.
        b: Offset of the generating function.
        noise: Standard deviation of the additive Gaussian noise.

    Returns:
        ``(x, y)``, both of shape ``(n, 1)``; neither requires gradients.
    """
    x = torch.rand(n, 1)
    target = Function(torch.tensor([w]), torch.tensor([b]))
    return x, target(x) + noise * torch.randn(n, 1)


def main():
    """Fit the quadratic model to noisy samples and plot an extrapolation."""
    # Imported here so the model code above can be imported without matplotlib.
    import matplotlib.pyplot as plt

    torch.manual_seed(0)
    x, y = generate_data(100)
    approx = Approximator(Function())
    approx.approximate(x, y)
    plt.scatter(x.numpy(), y.numpy())

    # Validate on inputs from [0, 2), i.e. partly outside the training range.
    x_val, _ = generate_data(100)
    x_val = x_val * 2
    with torch.no_grad():
        y_val = approx(x_val)
    plt.scatter(x_val.numpy(), y_val.numpy(), c="r")
    plt.show()


if __name__ == "__main__":
    main()
No comments:
Post a Comment