Saturday, 20 January 2024

Approximate Functions

Approximate Functions

import torch
import matplotlib.pyplot as plt


class Function:
    """Quadratic model ``f(x) = w * x**2 + b`` with tensor coefficients.

    When a coefficient is not supplied, it is initialised to a random tensor
    of shape ``(1,)`` with ``requires_grad=True`` so it can be fitted by
    gradient descent.
    """

    def __init__(self, w=None, b=None):
        """Create the model.

        Args:
            w: Coefficient of ``x**2``; a random trainable tensor if ``None``.
            b: Constant offset; a random trainable tensor if ``None``.
        """
        # Compare against None explicitly: ``w or ...`` would evaluate the
        # tensor's truthiness, silently replacing a zero-valued coefficient
        # with a random one and raising for tensors with more than one element.
        self.w = w if w is not None else torch.rand(1, requires_grad=True)
        self.b = b if b is not None else torch.rand(1, requires_grad=True)

    def __call__(self, x):
        """Evaluate ``w * x**2 + b`` on ``x`` (with torch broadcasting)."""
        return self.w * x ** 2 + self.b



class Approximator:
    """Fit the parameters of a callable model by plain gradient descent.

    The wrapped ``function`` exposes its trainable tensors either through a
    ``parameters()`` method or as ``w`` and ``b`` attributes. If it has a
    ``criterion`` attribute, that is used as the loss; otherwise the mean
    squared error is used.
    """

    def __init__(self, function, learning_rate=0.02):
        """Wrap ``function`` for fitting.

        Args:
            function: Callable model mapping inputs to predictions.
            learning_rate: Step size used for each gradient-descent update.
        """
        self.learning_rate = learning_rate
        self.function = function

        if hasattr(self.function, 'criterion'):
            self.criterion = self.function.criterion
        else:
            self.criterion = lambda y, y_hat: ((y - y_hat) ** 2).mean()

    def _parameters(self):
        """Return the list of tensors that gradient descent should update."""
        if hasattr(self.function, 'parameters'):
            return list(self.function.parameters())
        return [self.function.w, self.function.b]

    def approximate(self, x, y, epochs=1000):
        """Run ``epochs`` gradient-descent steps fitting ``function(x)`` to ``y``.

        Args:
            x: Model inputs.
            y: Target outputs compared against ``function(x)``.
            epochs: Number of update steps.

        Returns:
            The loss computed in the last step as a Python float, or ``None``
            when ``epochs`` is 0.
        """
        params = self._parameters()
        loss = None
        for _ in range(epochs):
            y_hat = self.function(x)
            loss = self.criterion(y, y_hat)
            loss.backward()

            # Update in place without recording the step in the autograd graph.
            with torch.no_grad():
                for param in params:
                    if param.grad is None:  # parameter not used by the loss
                        continue
                    param -= self.learning_rate * param.grad
                    # backward() accumulates into .grad, so reset it each step.
                    param.grad.zero_()
        return None if loss is None else loss.item()

    def __call__(self, x):
        """Evaluate the wrapped function on ``x``."""
        return self.function(x)



###################

torch.manual_seed(0)


def generate_data(n, w=4.0, b=0.0, noise=0.1):
    """Sample ``n`` noisy points from the quadratic ``y = w * x**2 + b``.

    Args:
        n: Number of samples.
        w: Coefficient of ``x**2`` in the ground-truth curve.
        b: Constant offset of the ground-truth curve.
        noise: Standard deviation of the additive Gaussian noise.

    Returns:
        ``(x, y)``, both of shape ``(n, 1)``; ``x`` is uniform in ``[0, 1)``.
    """
    x = torch.rand(n, 1)
    # Ground truth is computed directly so the target does not depend on the
    # model class that is later being fitted to it.
    y = w * x ** 2 + b
    return x, y + noise * torch.randn(n, 1)


x, y = generate_data(100)

function = Function()
approx = Approximator(function)
approx.approximate(x, y)

# Training data.
plt.scatter(x.detach().numpy(), y.detach().numpy(), label='training data')

# Validate on fresh inputs stretched to [0, 2), i.e. partly outside the
# training range, to show how the fitted curve extrapolates.
x, _ = generate_data(100)
x = x * 2
with torch.no_grad():  # inference only; no autograd graph needed
    prediction = approx(x)
plt.scatter(x.numpy(), prediction.numpy(), c='r', label='model prediction')
plt.legend()
plt.show()

No comments:

Post a Comment

Parse Wikipedia dump

""" This module processes Wikipedia dump files by extracting individual articles and parsing them into a structured format, ...