Model Distillation Explained: How DeepSeek Leverages the Technique for AI Success

Model distillation, also known as knowledge distillation, is a supervised learning technique that condenses the capabilities and thought processes of a large, pre-trained “teacher” model into a smaller “student” model. This allows the student model to achieve performance comparable to the teacher’s, but at lower cost and with faster inference. Chinese AI lab DeepSeek has gained attention for using distillation to create AI models that are cheaper and more efficient than those developed by leading tech companies.

How Model Distillation Works

The primary goal of model distillation is for the smaller “student” model to not only match the teacher’s accuracy but also to emulate its actual thought processes, including its style, cognitive abilities, and alignment with human values. The process involves transferring knowledge from a large, complex model (the teacher) to a smaller, more efficient one (the student).

Types of Model Distillation

There are three main approaches to model distillation:

Offline Distillation

A pre-trained teacher model’s weights are frozen as it transfers knowledge to the student model. The teacher model remains unmodified while the student is trained to replicate its outputs.

Online Distillation

Both the student and teacher models are updated simultaneously in an end-to-end training process. The student learns to track changes in the teacher model in real time, which is made possible through parallel processing.

Self-Distillation

The same network is used for both the teacher and the student model. The model learns from itself, transferring knowledge from the deeper layers to the shallower layers. This can help narrow the accuracy gap between teacher and student models.
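
To make the offline and online variants concrete, the following minimal PyTorch sketch contrasts the two on toy linear models and random data; all models, shapes and hyperparameters here are illustrative only, and the online formulation shown is just one of several possibilities.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy setup purely for illustration; models, shapes and learning rates are arbitrary
torch.manual_seed(0)
x = torch.randn(32, 10)               # a batch of inputs
labels = torch.randint(0, 3, (32,))   # ground-truth labels (used in the online case)
teacher, student = nn.Linear(10, 3), nn.Linear(10, 3)

# Offline distillation: the teacher (assumed pre-trained) stays frozen; only the student updates
teacher.eval()
opt_student = torch.optim.SGD(student.parameters(), lr=0.1)
with torch.no_grad():                 # no gradients flow into the teacher
    soft_labels = F.softmax(teacher(x), dim=1)
loss = F.kl_div(F.log_softmax(student(x), dim=1), soft_labels, reduction='batchmean')
opt_student.zero_grad()
loss.backward()
opt_student.step()

# Online distillation: teacher and student are optimized in the same training step
opt_both = torch.optim.SGD(list(teacher.parameters()) + list(student.parameters()), lr=0.1)
t_logits, s_logits = teacher(x), student(x)
loss = F.cross_entropy(t_logits, labels) + F.kl_div(
    F.log_softmax(s_logits, dim=1), F.softmax(t_logits, dim=1).detach(),
    reduction='batchmean')
opt_both.zero_grad()
loss.backward()
opt_both.step()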

Figure 1: Model distillation (simplified)

Regardless of the approach, distillation relies on three main components:

Teacher Model

A large, pre-trained model that acts as an expert system. Training this model requires extensive computation, massive datasets, and sophisticated optimization techniques to capture complex patterns and relationships within the data. Soft labels, which reflect the teacher’s confidence across the possible outcomes, are obtained by running the trained teacher on the training inputs. These soft labels provide a valuable dataset for training the student model.

Let \(f_t(\mathbf{x})\) denote the teacher model, which outputs logits \(z_i(\mathbf{x})\) for input \(\mathbf{x}\).

To convert logits into soft labels (probabilities), we use the softmax function with a temperature parameter \(t\) that controls the smoothness of the output distribution. This equation helps soften the teacher’s output probabilities to provide richer information for the student model. The softmax output for class \(i\) from the teacher model is given by [1]:

$$\displaystyle{y_i(\mathbf{x} \mid t)= \frac{e^{z_{i}(\mathbf{x})/t}}{\sum_{j} e^{z_{j}(\mathbf{x})/t}}}$$
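
As a quick, self-contained illustration (the logit values below are made up for the example), increasing the temperature flattens the distribution and exposes the teacher’s relative confidence in the non-winning classes:

import torch
import torch.nn.functional as F

logits = torch.tensor([4.0, 2.0, 0.5])   # hypothetical teacher logits for 3 classes
for t in (1.0, 2.0, 5.0):
    print(t, F.softmax(logits / t, dim=-1))
# t = 1 gives a sharp distribution; larger t yields smoother "soft labels"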

Student Model

A smaller model trained to mimic the behavior of the teacher model. The student model can vary in structure, from a simplified version of the teacher model to an entirely different network with an optimized structure.

Let \(f_s(\mathbf{x})\) denote the student model, which outputs logits \(z_i^s(\mathbf{x})\) for the same input \(\mathbf{x}\). These logits are converted to soft labels \(y_i^s(\mathbf{x})\) using the same temperature-scaled softmax; they are the student’s predictions for class \(i\).

Knowledge Transfer

The process of transferring knowledge from the teacher to the student involves training the student model on the soft labels generated by the teacher model, allowing it to learn the teacher’s decision-making patterns.

Distillation Loss Function

The primary objective of distillation is to minimize the difference between the student and teacher outputs. The distillation loss function can be defined as [2]:

$$\displaystyle{\mathcal{L}_{distill}= -\sum_{i} y_i^t(\mathbf{x})\, \log\left(y_i^s(\mathbf{x})\right)}$$

where \(y_i^t(\mathbf{x})\) is the softmax output of the teacher model at temperature \(t\) and \(y_i^s(\mathbf{x})\) is the softmax output of the student model at temperature \(t\). The loss captures how well the student mimics the teacher’s predictions.
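
A direct PyTorch translation of this loss might look like the sketch below; the logits and temperature are illustrative values only:

import torch
import torch.nn.functional as F

t = 2.0                                             # temperature
teacher_logits = torch.tensor([[4.0, 1.0, 0.2]])    # hypothetical logits
student_logits = torch.tensor([[2.5, 1.5, 0.3]])

y_t = F.softmax(teacher_logits / t, dim=1)          # teacher soft labels y_i^t
log_y_s = F.log_softmax(student_logits / t, dim=1)  # log of student soft labels y_i^s

distill_loss = -(y_t * log_y_s).sum(dim=1).mean()   # -sum_i y_i^t log(y_i^s)
print(distill_loss.item())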

Combined Loss Function

When ground truth labels are available, we can augment our loss function to include a term that measures how well the student predicts these labels. The combined loss function ensures that the student model learns from both the teacher’s knowledge (soft labels) and the ground truth (hard labels). It is the weighted combination of the two loss functions.

$$E(\mathbf{x}) = \alpha \cdot \mathcal{L}_{hard} + (1-\alpha) \cdot \mathcal{L}_{distill} $$

\(\mathcal{L}_{hard}\) is defined as the cross-entropy between the student’s predictions \(y_i(\mathbf{x}|1)\) at temperature \(t=1\) and the ground-truth labels \(\bar{y}_i\). This is the standard classification loss.

$$\mathcal{L}_{hard} = – \sum_i \bar{y}_i \log(y_i(\mathbf{x}|1)) $$

\(\mathcal{L}_{distill}\) is defined as the KL divergence (or another suitable divergence measure) between the student’s and teacher’s soft labels. This is the distillation loss component. KL divergence is a common choice among several available techniques for computing the distillation loss [3].

$$\mathcal{L}_{distill} = \sum_i y_i^t(\mathbf{x}) \log{\frac{y_i^t(\mathbf{x})}{y_i^s(\mathbf{x})}}$$

The \(\alpha\) parameter controls the relative importance of these two sources of information. A higher \(\alpha\) gives more weight to the ground truth, while a lower \(\alpha\) gives more weight to the teacher’s soft labels. In essence, the combined loss function is

$$E(\mathbf{x}) = -\alpha \cdot \sum_i \bar{y}_i \log(y_i(\mathbf{x}|1)) + (1 - \alpha) \cdot \sum_i y_i^t(\mathbf{x}) \log{\frac{y_i^t(\mathbf{x})}{y_i^s(\mathbf{x})}} $$

where \(\bar{y}_i\) represents the true labels. The first term encourages the student to learn from the true labels, while the second term encourages it to match the teacher’s predictions.

Figure 2: Model distillation (detailed)
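
A minimal sketch of the combined loss in PyTorch, reusing the same illustrative logits and hyperparameters as the earlier snippet and the KL-divergence form of the distillation term:

import torch
import torch.nn.functional as F

t, alpha = 2.0, 0.5                                 # illustrative hyperparameters
teacher_logits = torch.tensor([[4.0, 1.0, 0.2]])
student_logits = torch.tensor([[2.5, 1.5, 0.3]])
true_label = torch.tensor([0])

# Hard loss: cross-entropy with the ground truth at temperature t = 1
hard_loss = F.cross_entropy(student_logits, true_label)

# Distillation loss: KL divergence between teacher and student soft labels
distill_loss = F.kl_div(F.log_softmax(student_logits / t, dim=1),
                        F.softmax(teacher_logits / t, dim=1),
                        reduction='batchmean')

combined_loss = alpha * hard_loss + (1 - alpha) * distill_loss
print(combined_loss.item())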

Optimization Objective

The optimization objective during training is to minimize this combined loss function over the training dataset \(\mathcal{D}\):

$$\min_{\theta_s} \sum_{\mathbf{x} \in \mathcal{D}} E(\mathbf{x})$$

where \(\theta_s\) are the parameters of the student model.

Other flavours of knowledge transfer

In addition to matching final outputs, knowledge transfer can occur through various mechanisms: (1) Response-Based Distillation, which focuses on matching the final output probabilities, (2) Feature-Based Distillation, which involves minimizing differences between intermediate feature representations of both models (illustrated in the sketch below), and (3) Relation-Based Distillation, which captures relationships between inputs and outputs and often requires more sophisticated techniques.
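
As an example of the second mechanism, the sketch below matches intermediate activations of a frozen teacher and a smaller student through a learned projection; the layer sizes and the MSE objective are illustrative choices rather than a prescribed recipe:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical feature extractors; the student is narrower than the teacher
teacher_backbone = nn.Sequential(nn.Linear(784, 256), nn.ReLU())
student_backbone = nn.Sequential(nn.Linear(784, 64), nn.ReLU())
proj = nn.Linear(64, 256)       # maps student features into the teacher's feature space

x = torch.randn(32, 784)        # a batch of flattened inputs
with torch.no_grad():           # the teacher's features serve as fixed targets
    t_feat = teacher_backbone(x)
s_feat = proj(student_backbone(x))

feature_loss = F.mse_loss(s_feat, t_feat)   # minimize the representation gap
feature_loss.backward()         # gradients flow into the student and the projection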

Benefits of Model Distillation

Model distillation offers several advantages:

It requires less data to train the student model than building a comparable model from scratch, maximizing the utility of the available data.

Smaller, more efficient models are ideal for deployment on platforms with limited resources, such as edge computing and real-time processing.

It can improve accuracy on specific tasks, achieving performance similar to larger models with quicker response times.

Python code

The following Python script illustrates the concept of model distillation in PyTorch, where a smaller student model is trained to mimic the behavior of a larger, pre-trained teacher model. The SimpleModel class defines a simple feedforward neural network with one linear layer for demonstration purposes; in this toy example the teacher and student share the same architecture and the teacher is not actually pre-trained. Training is done on a subset of the MNIST dataset (first 1000 samples).

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

def model_distillation(teacher_model, student_model, train_loader, optimizer, temperature, alpha):
    """
    Implements model distillation to train a student model from a teacher model.

    Args:
        teacher_model (nn.Module): Pre-trained teacher model.
        student_model (nn.Module): Student model to be trained.
        train_loader (DataLoader): DataLoader for the training dataset.
        optimizer (torch.optim): Optimizer for the student model.
        temperature (float): Temperature for softening the softmax outputs.
        alpha (float): Weighting factor between the hard label loss and the distillation loss.
    """

    teacher_model.eval()  # Set teacher model to evaluation mode
    student_model.train() # Set student model to training mode

    for images, labels in train_loader:
        #images = images.to(device) # uncomment if using GPU
        #labels = labels.to(device) # uncomment if using GPU

        optimizer.zero_grad()

        # Get teacher's soft labels
        with torch.no_grad():
            teacher_outputs = teacher_model(images)
            soft_labels = F.softmax(teacher_outputs / temperature, dim=1)

        # Get student's predictions
        student_outputs = student_model(images)

        # Calculate hard label loss (cross-entropy with true labels)
        hard_loss = F.cross_entropy(student_outputs, labels)

        # Calculate distillation loss (KL divergence between softened student outputs and soft labels)
        student_soft_outputs = F.log_softmax(student_outputs / temperature, dim=1)
        distillation_loss = F.kl_div(student_soft_outputs, soft_labels, reduction='batchmean') * (temperature ** 2)

        # Combine losses
        loss = alpha * hard_loss + (1 - alpha) * distillation_loss

        # Backpropagation and optimization
        loss.backward()
        optimizer.step()


if __name__ == '__main__':
    # Example usage:

    # 1. Define a simple teacher and student model (replace with your actual models)
    class SimpleModel(nn.Module):
        def __init__(self):
            super(SimpleModel, self).__init__()
            self.linear = nn.Linear(784, 10)  # Example: MNIST-like input (28x28) to 10 classes

        def forward(self, x):
            x = x.view(x.size(0), -1)  # Flatten the input
            x = self.linear(x)
            return x

    teacher_model = SimpleModel()
    student_model = SimpleModel()

    # 2. Create dummy data and a DataLoader (replace with your actual dataset)
    import torchvision
    import torchvision.transforms as transforms
    from torch.utils.data import DataLoader

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) # Example transform for MNIST

    # Use a smaller subset of MNIST for demonstration
    trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    indices = torch.arange(0, 1000)  # Use the first 1000 samples
    subset = torch.utils.data.Subset(trainset, indices)
    train_loader = DataLoader(subset, batch_size=64, shuffle=True)

    # 3. Define an optimizer
    optimizer = optim.Adam(student_model.parameters(), lr=0.001)

    # 4. Set hyperparameters
    temperature = 5.0
    alpha = 0.5

    # 5. Train the student model using distillation
    model_distillation(teacher_model, student_model, train_loader, optimizer, temperature, alpha)

    print("Distillation complete!")

DeepSeek’s Success with Model Distillation

DeepSeek has leveraged model distillation to create AI models that are both cheaper and more efficient. By using this technique, DeepSeek and other smaller teams can train specialized models by having them learn from larger “teacher” models. This allows them to achieve performance levels close to those of models developed by leading tech companies, but with significantly reduced training time and resources. In one instance, researchers at Berkeley recreated OpenAI’s reasoning model for $450 in 19 hours using distillation [4]. Similarly, researchers at Stanford and the University of Washington created their own reasoning model in just 26 minutes, using less than $50 in compute credits [4]. This highlights the potential of distillation to democratize AI development, enabling less-capitalized startups and research labs to compete at the cutting edge [4].

References:

  1. http://arxiv.org/abs/1503.02531
  2. https://arxiv.org/abs/2403.09053
  3. https://arxiv.org/abs/2405.16852
  4. https://www.nbcnewyork.com/news/business/money-report/how-deepseek-used-distillation-to-train-its-artificial-intelligence-model-and-what-it-means-for-companies-such-as-openai/6158779/?os=iosno_journeystruegpbfyoah
  5. Articles on Machine learning: https://www.gaussianwaves.com/category/machine-learning/
