Tuesday, 27 February 2024

SPLADE

Implementing SPLADE for Efficient Information Retrieval: A PyTorch Guide

In the rapidly evolving field of information retrieval, the quest for models that balance efficiency with effectiveness is ongoing. One of the standout models that has garnered attention for its novel approach is SPLADE - Sparse Lexical and Expansion Model. This model intriguingly combines the depth of neural networks with the efficiency of sparse representations, making it a topic worth exploring. Today, I'll walk you through a simplified example of implementing SPLADE using PyTorch, aiming for clarity and content richness over jargon.

The Essence of SPLADE

At its core, SPLADE leverages the power of transformer-based models, like BERT, to generate sparse vectors for text. These vectors are not only efficient for computation but also effective in capturing the nuanced semantics of language, a crucial aspect for information retrieval tasks. The magic of SPLADE lies in its training process, where it employs a sparsity-inducing loss function. This encourages the model to zero in on the most relevant tokens, leaving behind a trail of zeroes for the rest.
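
To make the "trail of zeroes" concrete, here is a minimal sketch of how per-token scores can be collapsed into one sparse vector. It assumes the commonly described SPLADE pooling (a log-saturated ReLU followed by max pooling over the input tokens); the function name `splade_pool` and the toy logits are made up for illustration and are not real BERT output.

```python
import math

def splade_pool(token_logits):
    """Collapse per-token vocabulary logits into one sparse vector.

    token_logits: one row per input token, each row holding one logit
    per vocabulary entry (toy values here, not real model output).
    """
    vocab_size = len(token_logits[0])
    weights = []
    for j in range(vocab_size):
        # log(1 + ReLU(x)) dampens large logits and zeroes out negative ones
        activations = [math.log1p(max(0.0, row[j])) for row in token_logits]
        # Max pooling over the input tokens gives one weight per vocabulary entry
        weights.append(max(activations))
    return weights

# Two input tokens, a hypothetical vocabulary of four entries
toy_logits = [
    [2.0, -1.0, 0.0, -3.0],
    [0.5, -0.2, -4.0, 1.0],
]
vector = splade_pool(toy_logits)
# Keep only the non-zero entries, as an inverted index would
nonzero = {j: w for j, w in enumerate(vector) if w > 0}
print(nonzero)
```

Only vocabulary entries with at least one positive logit survive, which is what lets the result be stored and searched like a classic bag-of-words vector.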

Building SPLADE in PyTorch

Let's dive into the practical side. Our journey begins with the implementation of SPLADE atop PyTorch, a popular deep learning library known for its flexibility and ease of use. Here's a distilled guide to bringing SPLADE to life in your projects.

The SPLADE Model

Our model starts with a familiar face - BERT, leveraging its pre-trained might to understand language. We then introduce a linear layer tasked with scoring the relevance of text, focusing on the essence of what makes or breaks a match in information retrieval.

import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

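class SPLADE(nn.Module):
    def __init__(self, pretrained_model_name='bert-base-uncased'):
        super(SPLADE, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model_name)
        self.fc = nn.Linear(768, 1)  # A direct path to scoring relevance

    def forward(self, input_ids, attention_mask):
        # Encode the text and score it from BERT's pooled [CLS] representation
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.fc(outputs.pooler_output)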

Training with a Twist

Training SPLADE is where things get interesting. Our goal is not just to teach the model about relevance but also to embrace sparsity. This dual focus requires a careful balance, achieved through a specialized loss function.

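def train_splade(model, data_loader, optimizer, device):
    model.train()
    for batch in data_loader:
        # Assumes each batch is a dict of tensors with these three keys
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].float().to(device)
        optimizer.zero_grad()
        scores = model(input_ids, attention_mask).squeeze(-1)
        loss = compute_loss(scores, labels)  # A concoction of relevance and sparsity
        loss.backward()
        optimizer.step()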

def compute_loss(scores, labels):
    relevance_loss = nn.BCEWithLogitsLoss()(scores, labels)
    sparsity_loss = torch.norm(scores, p=1)  # L1 norm for sparsity
    return relevance_loss + 0.01 * sparsity_loss  # The delicate dance of balancing

The Takeaway

Implementing SPLADE from scratch in PyTorch illuminates the intriguing blend of neural depth and sparse efficiency. This example, while simplified, captures the essence of SPLADE's approach to information retrieval. It's a testament to the model's innovative leveraging of sparsity-inducing techniques, ensuring that every token counts in the vast sea of information.

Diving into SPLADE opens up new vistas for tackling the challenges of information retrieval, providing a framework that's both efficient and effective. As you embark on integrating SPLADE into your projects, remember that the true power of this model lies in its nuanced balance, a reminder of the intricate dance between precision and practicality in the digital age. 

Cross Encoders

A cross-encoder in the context of information retrieval is a type of model architecture designed to enhance the performance of retrieving the most relevant information from a large dataset based on a given query. Unlike traditional or simpler encoder models that process queries and documents separately, cross-encoders evaluate the relevance of a document to a query by jointly encoding both the query and the document together in a single pass. This approach allows the model to consider the intricate interactions between the query and the document, leading to a more nuanced understanding and often superior retrieval performance.

Key characteristics:

  • Joint Encoding: Cross-encoders take both the query and a candidate document as input and combine them into a single input sequence, often with a special separator token in between. This allows the model to directly learn interactions between the query and the document.
  • Fine-Grained Understanding: By considering the query and document together, cross-encoders can better capture the nuances of relevance, including context, semantic similarities, and specific details that might be missed when encoding them separately.
  • Computational Intensity: While providing high accuracy, cross-encoders are computationally more intensive than other architectures like bi-encoders, because each query-document pair must be processed together. This can make them less efficient for applications that require scanning through very large datasets.
  • Use in Ranking: Cross-encoders are particularly useful for the ranking stage of information retrieval, where a smaller subset of potentially relevant documents (pre-filtered by a more efficient method, like a bi-encoder) needs to be ranked accurately according to their relevance to the query.
  • Training and Fine-tuning: Cross-encoders can be trained or fine-tuned on specific tasks or datasets, allowing them to adapt to the nuances of different domains or types of information retrieval tasks.

In practice, cross-encoders are often used in combination with other models in a two-step retrieval and ranking process. An initial, more efficient model (such as a bi-encoder) quickly narrows down the search space to a manageable number of candidate documents, and then a cross-encoder is applied to this subset to accurately rank the documents in order of relevance to the query. This approach balances the need for both efficiency and high accuracy in information retrieval systems.
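
The two-step process can be sketched without any model at all. In this toy version both scorers are simple word-overlap functions standing in for a bi-encoder and a cross-encoder; the names `cheap_score`, `expensive_score` and `retrieve_then_rerank` are hypothetical and only illustrate the control flow.

```python
def cheap_score(query, doc):
    """Stand-in for a fast first-stage scorer: number of shared words."""
    return len(set(query.lower().split()) & set(doc.lower().split()))

def expensive_score(query, doc):
    """Stand-in for a cross-encoder: Jaccard overlap of the word sets."""
    q, d = set(query.lower().split()), set(doc.lower().split())
    return len(q & d) / len(q | d)

def retrieve_then_rerank(query, docs, first_stage, reranker, k=3):
    # Stage 1: score every document cheaply and keep only the top k
    candidates = sorted(docs, key=lambda d: first_stage(query, d), reverse=True)[:k]
    # Stage 2: run the expensive scorer on the k survivors only
    return sorted(candidates, key=lambda d: reranker(query, d), reverse=True)

docs = [
    "artificial intelligence is a branch of computer science",
    "what is the weather today",
    "intelligence is artificial",
    "cooking pasta at home",
]
ranked = retrieve_then_rerank("what is artificial intelligence", docs,
                              cheap_score, expensive_score, k=3)
for doc in ranked:
    print(doc)
```

The expensive scorer runs k times instead of once per document, which is the whole point of the split.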


Pseudo code:
import torch
from transformers import BertTokenizer, BertForSequenceClassification

class CrossEncoder(torch.nn.Module):
    def __init__(self, pretrained_model_name='bert-base-uncased'):
        super(CrossEncoder, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
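        # Assuming a binary classification model where 1 indicates relevance.
        # The classification head is randomly initialized here, so the scores
        # only become meaningful after fine-tuning on relevance labels.
        self.model = BertForSequenceClassification.from_pretrained(pretrained_model_name, num_labels=2)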

    def forward(self, query, document):
        # Tokenize query and document together, separating them with a [SEP] token
        inputs = self.tokenizer.encode_plus(query, document, return_tensors='pt', add_special_tokens=True, truncation=True, max_length=512)
        # Forward pass through the model
        outputs = self.model(**inputs)
        # Get the logits
        logits = outputs.logits
        # Apply softmax to get probabilities
        probabilities = torch.softmax(logits, dim=1)
        # Assuming label 1 corresponds to "relevant"
        relevance_probability = probabilities[:, 1]
        return relevance_probability

# Example usage
cross_encoder = CrossEncoder()

# Example query and document
query = "What is artificial intelligence?"
document = "Artificial intelligence is a branch of computer science that aims to create intelligent machines."

# Compute relevance score
relevance_score = cross_encoder(query, document)

print(f"Relevance Score: {relevance_score.item()}")


Tuesday, 20 February 2024

Download a Wikipedia Category

import wikipediaapi
import logging 
import json
from pathlib import Path
from slugify import slugify
import click

logger = logging.getLogger(__name__)

wiki_wiki = wikipediaapi.Wikipedia('parasort (tom-010@web.de)', 'en')

logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)

@click.command()
@click.argument('category', type=str)
@click.option('--out', type=click.Path(path_type=Path), default='out', help='Output directory')
def main(category:str, out:Path):
    download_wikipedia_category(category, out)


def save_page(page, path:Path, breadcumbs: list[str]):
    """Save the content of a Wikipedia page into a file."""

    logger.info(f'Saving page: {page.title}')

    path.mkdir(parents=True, exist_ok=True)
    target = path / f'{slugify(page.title)}.json'
    if target.exists():
        logger.info(f'Page already exists: {page.title}')
        return

    page_json = {
        'title': page.title,
        'summary': page.summary,
        'pageid': page.pageid,
        'breadcumbs': breadcumbs,
        'sections': []
    }
    for section in page.sections:
        page_json['sections'].append({
            'level': section.level,
            'title': section.title,
            'text': section.full_text()
        })
    
    content = page.text
    if content:
        with target.open('w', encoding='utf-8') as f:
            json.dump(page_json, f, ensure_ascii=False, indent=2)

def process_category(category, path:Path, breadcumbs: list[str] = None):
    """Process each category and its subcategories."""
    if breadcumbs is None:
        breadcumbs = []

    logger.info(f'Processing category: {category}')

    for c in category.categorymembers.values():
        if c.ns == wikipediaapi.Namespace.CATEGORY:
            # Create a directory for the subcategory
            category_name = c.title.replace('Category:', '')
            sub_path = path / slugify(category_name)
            # Pass a new list so sibling categories do not inherit this breadcrumb
            process_category(c, sub_path, breadcumbs + [category_name])
        else:
            save_page(c, path, breadcumbs)

def download_wikipedia_category(category_name:str, out_dir:Path):
    """Download all pages from a Wikipedia category and its subcategories."""
    cat = wiki_wiki.page("Category:" + category_name)
    target = out_dir / slugify(category_name)
    process_category(cat, target)

if __name__ == '__main__':
    main()

Sunday, 18 February 2024

Strategy for my keyboard layout

  1. Write down all keys on my existing keyboard
  2. Build a realistic corpus from a lot of documents (mainly my own, but also similar ones, plus a keylogger)
  3. Specify how many keys I have
  4. Count the occurrences of each key and sort them
  5. I have more keys than single characters. Search for frequent patterns, as SentencePiece does.
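
Steps 4 and 5 can be sketched with `collections.Counter`. This is only a starting point under simple assumptions: the corpus is a plain string, a "key" is a single character, and a "pattern" is a fixed-length character n-gram (SentencePiece itself learns variable-length pieces). The function name `count_keys_and_ngrams` and the tiny corpus are made up for illustration.

```python
from collections import Counter

def count_keys_and_ngrams(text, n=2):
    """Count single characters and character n-grams in a corpus string."""
    key_counts = Counter(text)
    # Slide a window of length n over the text to collect candidate patterns
    ngram_counts = Counter(text[i:i + n] for i in range(len(text) - n + 1))
    return key_counts, ngram_counts

corpus = "the theme of the thesis"
keys, bigrams = count_keys_and_ngrams(corpus)

# Most frequent single keys first, then the most frequent two-key patterns
print(keys.most_common(3))
print(bigrams.most_common(3))
```

Frequent n-grams are candidates for the spare keys; longer patterns can be explored by raising `n`.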

Saturday, 17 February 2024

Python: Batch an Iterator

from itertools import islice

def batched(iterator, batch_size):
    """
    Takes an iterator and a batch-size, and returns an iterator that yields batches of the given size.
    
    Parameters:
    - iterator: An iterator from which to generate batches.
    - batch_size: An integer specifying the size of each batch.
    
    Yields:
    - Batches of elements from the input iterator, each batch being a list of elements up to the specified batch_size.
    """
    iterator = iter(iterator)  # Ensure it's an iterator
    while True:
        batch = list(islice(iterator, batch_size))
        if not batch:
            break
        yield batch

Usage:

for batch in batched(range(80), 9):
    print(batch)


d = {str(idx): idx for idx in range(10)}

for batch in batched(d.items(), 3):
    for key, value in batch:
        print(key, value)

Output

[0, 1, 2, 3, 4, 5, 6, 7, 8]
[9, 10, 11, 12, 13, 14, 15, 16, 17]
[18, 19, 20, 21, 22, 23, 24, 25, 26]
[27, 28, 29, 30, 31, 32, 33, 34, 35]
[36, 37, 38, 39, 40, 41, 42, 43, 44]
[45, 46, 47, 48, 49, 50, 51, 52, 53]
[54, 55, 56, 57, 58, 59, 60, 61, 62]
[63, 64, 65, 66, 67, 68, 69, 70, 71]
[72, 73, 74, 75, 76, 77, 78, 79]
0 0
1 1
2 2
3 3
4 4
5 5
6 6
7 7
8 8
9 9

Friday, 16 February 2024

Skill estimation and recommender systems

In a recommender system, you try to predict empty cells in the matrix between users and items.

In skill estimation, you try to predict empty cells in the matrix between players and games.

Collaborative filtering also has a nice intuition.
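
A rough sketch of that intuition, assuming a user-based neighbourhood approach (one of several collaborative filtering variants): an empty cell is filled with the ratings other users gave that item, weighted by how similar those users are on the items both have rated. The names and the ratings are invented; swapping users for players and items for games gives the skill-estimation view.

```python
import math

def similarity(ratings, a, b):
    """Cosine similarity between two users over the items both have rated."""
    shared = [i for i in ratings[a] if i in ratings[b]]
    if not shared:
        return 0.0
    dot = sum(ratings[a][i] * ratings[b][i] for i in shared)
    norm_a = math.sqrt(sum(ratings[a][i] ** 2 for i in shared))
    norm_b = math.sqrt(sum(ratings[b][i] ** 2 for i in shared))
    return dot / (norm_a * norm_b)

def predict(ratings, user, item):
    """Fill one empty cell with a similarity-weighted average of known ratings."""
    weighted, total = 0.0, 0.0
    for other, row in ratings.items():
        if other != user and item in row:
            w = similarity(ratings, user, other)
            weighted += w * row[item]
            total += w
    return weighted / total if total else None

# Sparse user x item matrix; ann has not rated item 'c'
ratings = {
    'ann': {'a': 5, 'b': 1},
    'bob': {'a': 5, 'b': 1, 'c': 4},
    'cat': {'a': 1, 'b': 5, 'c': 1},
}
prediction = predict(ratings, 'ann', 'c')
print(prediction)
```

Because ann's known ratings match bob's exactly, the prediction lands closer to bob's rating of 'c' than to cat's.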

pdict: Persistent dict (python)

from pathlib import Path
from sqlitedict import SqliteDict
import sqlite3

def pdict(db_path:str|Path):
    if not isinstance(db_path, Path):
        db_path = Path(db_path)
    if not db_path.exists():
        db_path.parent.mkdir(exist_ok=True, parents=True)
        conn = sqlite3.connect(db_path)
        conn.execute('PRAGMA journal_mode=WAL;')
        conn.commit()
        conn.close()
    return SqliteDict(db_path, autocommit=True)

Wednesday, 7 February 2024

Generative Deep Learning, Chapter 1

Challenge

Create an image classifier with this architecture:

  • Fully Connected Layer, size=200
  • ReLU
  • Fully Connected Layer, size=150
  • ReLU
  • Fully Connected Layer, size=10 (output)
Train it on the CIFAR10 dataset (lr=0.001). Evaluate it. Plot a few images with the predicted and actual labels.


import numpy as np
import torchvision
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn, optim
from torch.nn import functional as F
import matplotlib.pyplot as plt



NUM_CLASSES = 10
BATCH_SIZE = 64


# load dataset

dataset = torchvision.datasets.CIFAR10('cifar10', download=True)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


# define the architecture

class MLP(nn.Module):

  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(32*32*3, 200)
    self.fc2 = nn.Linear(200,150)
    self.fc3 = nn.Linear(150,10)
    self.relu = nn.ReLU()

  def forward(self, x):
    out = x.view(x.size(0), -1)
    out = self.fc1(out)
    out = self.relu(out)
    out = self.fc2(out)
    out = self.relu(out)
    out = self.fc3(out)
    return out
    
    
# train it

from tqdm import tqdm


model = MLP()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10

for epoch in range(num_epochs):
  running_loss = 0.0
  for images, labels in tqdm(train_loader):
    # Flatten the images for the MLP
    images = images.view(images.size(0), -1)
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
  print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')
  
  
  
# evaluate it
 
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc='Evaluating'):
        images = images.view(images.size(0), -1)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f'Accuracy of the network on the 10000 test images: {accuracy:.2f}%') # 53.19%


# predict the classes

import numpy as np

classes = np.array(['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'])

model.eval()

preds_list = []
actuals_list = []
with torch.no_grad():
    for images, labels in test_loader:
        images = images.view(images.size(0), -1)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        preds_list.extend(preds.cpu().numpy())
        actuals_list.extend(labels.cpu().numpy())

preds_array = np.array(preds_list)
actuals_array = np.array(actuals_list)
preds_single = classes[preds_array]
actual_single = classes[actuals_array]


# visualize it

def get_image(idx, loader):
  for i, (images, _) in enumerate(loader):
    if i == idx // BATCH_SIZE:  # Find the batch containing the image at position idx
        image = images[idx % BATCH_SIZE]  # Pick that image out of the batch
        return image
  return None


n_to_show = 10
indices = np.random.choice(range(1000), n_to_show)

fig = plt.figure(figsize=(15, 3))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

for i, idx in enumerate(indices):
  img = get_image(idx, test_loader)
  ax = fig.add_subplot(1, n_to_show, i+1)
  ax.axis('off')
  ax.text(0.5, -0.35, 'pred='+str(preds_single[idx]), fontsize=10, ha='center', transform=ax.transAxes)
  ax.text(0.5, -0.7, 'actu='+str(actual_single[idx]), fontsize=10, ha='center', transform=ax.transAxes)
  img = img.numpy().transpose(1,2,0)
  img = img * 0.5 + 0.5
  img = np.clip(img, 0, 1)
  ax.imshow(img)

Notes:

The transforms.Normalize((0.5,), (0.5,)) transform standardizes the pixel values of the images in the dataset.

Normalization typically changes the range of pixel intensity values. The Normalize transform does this by applying the following transformation to each channel of the image:

normalized_channel = (channel - mean) / std

For the CIFAR-10 dataset, images are in RGB format, meaning they have three channels (Red, Green, and Blue), each with pixel values in the range [0, 1] after applying transforms.ToTensor().

The Normalize transform here is called with (0.5,) for both the mean and std (standard deviation) parameters, but since CIFAR-10 images have three channels, and you provided a single value, it implicitly applies these values to all three channels.

The choice of 0.5 for both the mean and the standard deviation shifts the pixel value range from [0, 1] (the output of ToTensor()) to [-1, 1]: subtracting 0.5 centers the values around 0, and dividing by 0.5 stretches them to [-1, 1].

Operating in a [-1, 1] range can make the training process more stable and efficient for many models by ensuring that the inputs start in a more uniform and centered distribution. This is particularly beneficial for activation functions and optimization algorithms, making it easier to tune hyperparameters and achieve better performance.
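
The arithmetic is easy to check by hand. The sketch below applies the same formula to a few scalar pixel values; `normalize` and `denormalize` are illustrative helpers, not torchvision functions, and the second one mirrors the `img * 0.5 + 0.5` step used before plotting.

```python
def normalize(value, mean=0.5, std=0.5):
    """Same formula Normalize applies per channel: (value - mean) / std."""
    return (value - mean) / std

def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, used to bring images back to [0, 1] for display."""
    return value * std + mean

# Pixel values as produced by ToTensor(), i.e. already scaled to [0, 1]
for pixel in (0.0, 0.25, 0.5, 1.0):
    print(pixel, '->', normalize(pixel))
```

The endpoints 0.0 and 1.0 map to -1.0 and 1.0, and 0.5 maps to 0.0, so the whole [0, 1] range ends up centered on zero.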




Tuesday, 6 February 2024

Naive skill estimation

Naive Skill Estimation. As simple as it gets.



500 different seeds lead to this distribution:




Mean: 0.5333787274909965
Std: 0.15858765760754445




...................................

from dataclasses import dataclass
import random
import trueskill
from scipy.stats import spearmanr


# set seed 
random.seed(0)

def run(n_games):

    @dataclass
    class Player:
        name: str
        skill: int
        estimated_skill: trueskill.Rating = None

    n_players = 100
    skill_n = 100
    players = [Player(f"Player {i}", random.randint(1, skill_n)) for i in range(n_players)]
    ground_truth = [player for player in players]
    ground_truth.sort(key=lambda x: x.skill, reverse=True)


    players.sort(key=lambda x: x.skill, reverse=True)

    random.shuffle(players)


    @dataclass 
    class GameScore:
        player: Player
        score: float


    @dataclass 
    class Game:
        name: str
        players: list[Player]
        scores: list[GameScore]

    games:list[Game] = []

    for i in range(n_games):
        n_players_in_this_game = random.randint(2, n_players)
        game_players = random.sample(players, n_players_in_this_game)

        sigma = random.randint(1, int(skill_n * 0.2))

        scores = []
        for player in game_players:
            in_game_skill = player.skill + random.gauss(0, sigma)
            scores.append(GameScore(player, in_game_skill))

        scores.sort(key=lambda x: x.score, reverse=True)

        games.append(Game(f'Game {i}', game_players, scores))



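    def get_score(ground_truth: list[Player], players: list[Player]) -> float:
        # Note: the inputs are player names, so the correlation is computed between
        # the two name sequences rather than between true and estimated skill values.
        expected = [player.name for player in ground_truth]
        actual = [player.name for player in players]
        return spearmanr(expected, actual).correlation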


    for player in players:
        player.estimated_skill = 0


    for game in games:

        for i, score in enumerate(game.scores):
            score.player.estimated_skill += len(game.scores) - i


    players.sort(key=lambda x: x.estimated_skill, reverse=True)


    return get_score(ground_truth, players)


from tqdm import tqdm
x = list(range(1, 2000, 5))
y = [run(n) for n in tqdm(x)]

import matplotlib.pyplot as plt
plt.plot(x, y)
plt.show()

Histogram created via:

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

results = []
for i in tqdm(range(500)):
    random.seed(i)
    results.append(run(1000))

results = np.array(results)

# histogram with 30 bins
plt.hist(results, bins=30)
plt.show()

print(f"Mean: {results.mean()}")
print(f"Std: {results.std()}")

Monday, 5 February 2024

Develop tooling for jupyter notebooks

  • Create a setuptools project
  • Create a folder where you place your notebook
  • Start your notebook and run `pip3 install -e ..`
  • Paste this into a cell:
    %load_ext autoreload
    %autoreload 2
    

Sunday, 4 February 2024

Intercept tqdm

For example to track the progress in a database or stream it via http

from tqdm import tqdm
import time

class TqdmContextManager:
    """
    A context manager to temporarily override the default tqdm class with a custom one.
    This allows for custom behavior of tqdm progress bars globally within its context.
    """
    def __init__(self):
        # Save a reference to the original tqdm class
        self.original_tqdm = tqdm

    def __enter__(self):
        # Override the global tqdm reference with our custom class upon entering the context
        global tqdm
        tqdm = CustomTqdm
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the original tqdm class upon exiting the context
        global tqdm
        tqdm = self.original_tqdm


class CustomTqdm(tqdm):
    """
    A custom tqdm subclass that limits update frequency of progress output.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_print_time = time.time()  # Record the start time
        self.update_every = 0.1  # Limit to printing every 0.1 seconds (10 times per second)

    def update(self, n=1):
        """
        Override the update method to control the frequency of progress output.
        """
        # Let tqdm do its own bookkeeping first so that self.n actually advances
        displayed = super().update(n)
        current_time = time.time()
        if current_time - self.last_print_time >= self.update_every:
            # Print progress if enough time has passed since the last print
            print(f"Progress: {self.n}/{self.total}", end='\r', flush=True)
            self.last_print_time = current_time  # Update the last print time
        return displayed


# Example usage of the custom tqdm with a context manager
def some_function():
    """
    A sample function to demonstrate the use of the custom tqdm progress bar.
    """
    for i in tqdm(range(50)):
        time.sleep(0.05)  # Simulate work with a delay

# Use the custom tqdm progress bar within the context manager
with TqdmContextManager():
    some_function()

Parse Wikipedia dump

""" This module processes Wikipedia dump files by extracting individual articles and parsing them into a structured format, ...