Demystifying PyTorch Static Quantization

A deep dive into how PyTorch performs inference with quantized models.

Author: Haruka Doyu
Published: August 12, 2024
Modified: August 29, 2024

In the world of machine learning, optimizing model performance and efficiency is crucial, especially for deploying models on edge devices with limited resources. One powerful technique to achieve this is quantization, which reduces the precision of the numbers used in a model’s computations. PyTorch supports two types of quantization: dynamic and static. Dynamic quantization adjusts the precision of weights at runtime, while static quantization involves converting the model’s weights and activations to lower precision based on calibration data. This article will focus on statically quantized models, breaking down the core concepts and steps involved in PyTorch’s approach to inference with these models.

Note: This article assumes you are already familiar with quantization, particularly static quantization. If not, I recommend checking out some introductory material, e.g., our technology page.

from fastai.vision.all import *

import torch
from torch.ao.quantization import get_default_qconfig_mapping
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

Let’s start by creating a Quantizer class to quantize a PyTorch model. For an introduction to PyTorch quantization, you can refer to the official documentation. As an example, I will use the Imagenette2-320 dataset and the ResNet18 model. For convenience, I will leverage the Fastai learner to streamline this process.

class Quantizer():
    def __init__(self, backend="x86"):
        self.qconfig = get_default_qconfig_mapping(backend)
        torch.backends.quantized.engine = backend

    def quantize(self, model, calibration_dls):
        # prepare_fx inserts observers into the model; it expects example inputs as a tuple
        x, _ = calibration_dls.valid.one_batch()
        model_prepared = prepare_fx(model.eval(), self.qconfig, (x,))
        # Calibration: run the validation set through the prepared model so the
        # observers can record activation ranges
        with torch.no_grad():
            _ = [model_prepared(xb.to('cpu')) for xb, _ in calibration_dls.valid]

        return model_prepared, convert_fx(model_prepared)
path = untar_data(URLs.IMAGENETTE_320, data=Path.cwd()/'data')
dls = ImageDataLoaders.from_folder(path, valid='val', item_tfms=Resize(224),
                                   batch_tfms=Normalize.from_stats(*imagenet_stats))
learn = vision_learner(dls, resnet18)
model_prepared, qmodel = Quantizer("qnnpack").quantize(learn.model, learn.dls)

In static quantization, the scaling factors and zero points for weights and activations are determined after model calibration but before inference. In this context, we are using per-tensor quantization, which means that there is a single scaling factor and zero point applied uniformly across all elements in each tensor of a layer. This approach is straightforward and computationally efficient, as it simplifies the quantization process by treating the entire tensor as a whole.
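
For intuition, here is a minimal sketch of per-tensor affine quantization using torch.quantize_per_tensor; the tensor, scale, and zero point below are made-up values for illustration only:

# Per-tensor affine quantization: every element shares one scale and one zero point
x = torch.tensor([-2.0, -0.5, 0.0, 1.0, 2.5])
scale, zero_point = 0.02, 128                 # illustrative values, not from the model

xq = torch.quantize_per_tensor(x, scale, zero_point, dtype=torch.quint8)
print(xq.int_repr())    # integer representation: round(x / scale) + zero_point, clamped to [0, 255]
print(xq.dequantize())  # approximate reconstruction of x: (q - zero_point) * scale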

In the cell above, the model_prepared instance represents the model after it has recorded the range of activations across the validation dataset. This model contains the necessary information about the model structure and activation ranges, from which the scaling factors and zero points are calculated. Below is an example of the quantization parameters for some activations. The HistogramObserver is used to record the activation ranges. The first output shows the quantization parameters of the first activation, which is the model input, while the second output shows the quantization parameters of the second activation, which is the output of the first Conv2d + ReLU layer. In PyTorch, to avoid redundant quantization and dequantization between layers, batch normalization is folded into the preceding layer (batch normalization folding), and the ReLU layer is fused with the layer it follows.

# Example activation quantization parameters
for i in range(3):
    attr = getattr(model_prepared, f"activation_post_process_{i}")
    scale, zero_p = attr.calculate_qparams()
    print("{}\nScaling Factor: {}\nZero Point: {}\n".format(attr, scale.item(), zero_p.item()))
HistogramObserver(min_val=-2.1179039478302, max_val=2.640000104904175)
Scaling Factor: 0.018649335950613022
Zero Point: 114

HistogramObserver(min_val=0.0, max_val=7.000605583190918)
Scaling Factor: 0.011327190324664116
Zero Point: 0

HistogramObserver(min_val=0.0, max_val=7.000605583190918)
Scaling Factor: 0.011327190324664116
Zero Point: 0
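
To see how such an observer behaves in isolation, you can instantiate a HistogramObserver directly, feed it a few tensors, and ask it for the resulting quantization parameters; the random inputs below are just for illustration:

from torch.ao.quantization.observer import HistogramObserver

obs = HistogramObserver(dtype=torch.quint8)   # the observer type used by the default qconfig above
for _ in range(4):
    obs(torch.randn(8, 16))                   # record value ranges over a few "batches"
scale, zero_point = obs.calculate_qparams()
print(obs, scale.item(), zero_point.item())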

The qmodel instance represents the quantized model. It contains the quantized weights, along with their associated scaling factor and zero point, as well as the scaling factors and zero points for the activations. Additionally, it includes some non-quantized parameters, which I will explain later.

qmodel
GraphModule(
  (0): Module(
    (0): QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.011327190324664116, zero_point=0, padding=(3, 3))
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Module(
      (0): Module(
        (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.008901300840079784, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.024013830348849297, zero_point=149, padding=(1, 1))
      )
      (1): Module(
        (conv1): QuantizedConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.007031331304460764, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.031252723187208176, zero_point=156, padding=(1, 1))
      )
    )
    (5): Module(
      (0): Module(
        (conv1): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(2, 2), scale=0.007301042787730694, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.019116230309009552, zero_point=124, padding=(1, 1))
        (downsample): Module(
          (0): QuantizedConv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), scale=0.01664934679865837, zero_point=135)
        )
      )
      (1): Module(
        (conv1): QuantizedConvReLU2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.008282394148409367, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.02566305175423622, zero_point=137, padding=(1, 1))
      )
    )
    (6): Module(
      (0): Module(
        (conv1): QuantizedConvReLU2d(128, 256, kernel_size=(3, 3), stride=(2, 2), scale=0.010484358295798302, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.02675902470946312, zero_point=90, padding=(1, 1))
        (downsample): Module(
          (0): QuantizedConv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), scale=0.008271278813481331, zero_point=162)
        )
      )
      (1): Module(
        (conv1): QuantizedConvReLU2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.00832998938858509, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), scale=0.027811763808131218, zero_point=142, padding=(1, 1))
      )
    )
    (7): Module(
      (0): Module(
        (conv1): QuantizedConvReLU2d(256, 512, kernel_size=(3, 3), stride=(2, 2), scale=0.006999513134360313, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.023119885474443436, zero_point=140, padding=(1, 1))
        (downsample): Module(
          (0): QuantizedConv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), scale=0.02033478580415249, zero_point=128)
        )
      )
      (1): Module(
        (conv1): QuantizedConvReLU2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.006345659960061312, zero_point=0, padding=(1, 1))
        (conv2): QuantizedConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), scale=0.12105856835842133, zero_point=88, padding=(1, 1))
      )
    )
  )
  (1): Module(
    (0): Module(
      (mp): AdaptiveMaxPool2d(output_size=1)
      (ap): AdaptiveAvgPool2d(output_size=1)
    )
    (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): QuantizedDropout(p=0.25, inplace=False)
    (4): QuantizedLinearReLU(in_features=1024, out_features=512, scale=0.08005672693252563, zero_point=0, qscheme=torch.per_tensor_affine)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): QuantizedDropout(p=0.5, inplace=False)
    (8): QuantizedLinear(in_features=512, out_features=10, scale=0.10456003248691559, zero_point=150, qscheme=torch.per_tensor_affine)
  )
)

Let’s investigate the first layer of qmodel, i.e., the quantized Conv2d + ReLU layer.

layer = qmodel._modules['0']._modules['0']
print(layer)
print("Weight Scale: {}, Weight Zero Point: {}".format(layer.weight().q_scale(),
                                                       layer.weight().q_zero_point()))
print("Output Scaling Factor: {}, Output Zero Point: {}\n".format(layer.scale, 
                                                                  layer.zero_point))

print("Example weights:", layer.weight()[0, 0, 0])
print("In integer representation:", layer.weight()[0, 0, 0].int_repr())
QuantizedConvReLU2d(3, 64, kernel_size=(7, 7), stride=(2, 2), scale=0.011327190324664116, zero_point=0, padding=(3, 3))
Weight Scale: 0.0030892190989106894, Weight Zero Point: 0
Output Scaling Factor: 0.011327190324664116, Output Zero Point: 0

Example weights: tensor([-0.0031,  0.0000,  0.0000,  0.0185,  0.0124,  0.0031, -0.0031],
       size=(7,), dtype=torch.qint8, quantization_scheme=torch.per_tensor_affine,
       scale=0.0030892190989106894, zero_point=0)
In integer representation: tensor([-1,  0,  0,  6,  4,  1, -1], dtype=torch.int8)

As shown above, the quantized layer contains two scaling factors and zero points: one for the weights and another for the output activation. You may have noticed that the output scaling factor and zero point are the same as those of the second observer displayed earlier, since they describe the same activation.
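
As a quick sanity check (assuming the observer indexing shown earlier, where activation_post_process_1 corresponds to the output of this layer), you can confirm that the converted layer's output quantization parameters come straight from that observer:

# The converted layer's output scale/zero point should equal the values computed
# by the observer that recorded this activation in model_prepared
obs_scale, obs_zero_p = getattr(model_prepared, "activation_post_process_1").calculate_qparams()
print(layer.scale, obs_scale.item())        # both should print the same scaling factor
print(layer.zero_point, obs_zero_p.item())  # both should print the same zero point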

What about biases, which I haven’t discussed yet? In PyTorch, bias quantization depends on the backend you use. For example, if you specify the x86 backend, biases are not quantized and are instead used as floating-point values. The QNNPACK backend, on the other hand, quantizes biases. However, biases are not quantized during the initial quantization stage; they are quantized at inference time. Thus, even though inference uses quantized biases, PyTorch does not display them before inference. The formula for bias quantization in PyTorch is \[ b_q = \mathrm{round}\!\left(\frac{b}{s_i \cdot s_w}\right), \] where \(b_q\) is the quantized bias, \(b\) is the bias before quantization, \(s_i\) is the input activation scale, and \(s_w\) is the weight scale. For more details, you can refer to this discussion.
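
To make the formula concrete, here is a small sketch of the bias quantization step with made-up numbers; the real bias of the first layer is quantized the same way in the inference walk-through below:

# b_q = round(b / (s_i * s_w)); the values below are illustrative only
b = torch.tensor([0.25, -0.10, 0.03])   # float biases
si, sw = 0.0186, 0.0031                 # input activation scale and weight scale
b_q = torch.round(b / (si * sw)).to(torch.int32)
print(b_q)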

In addition, the model may include other non-quantized parameters, such as parameters in batch normalization layers that are not fused. This is likely because quantizing the activations in these layers would not provide significant benefits.

What happens during inference?

This section demonstrates how calculations are performed in the quantized model during inference. To illustrate this, I calculate the output of the first convolutional layer and validate it against the actual result.

layer_input = None
layer_output = None

def hook_fn(module, input, output):
    global layer_output, layer_input
    layer_input = input
    layer_output = output

img = torch.rand([1, 3, 224, 224])
hook = qmodel._modules['0']._modules['0'].register_forward_hook(hook_fn)
output = qmodel(img)
hook.remove()
print("Example input:", layer_input[0][0,0,0,:10].int_repr())
print("Example output:", layer_output[0,0,0,:10].int_repr())
Example input: tensor([163, 119, 155, 138, 126, 164, 115, 132, 115, 166], dtype=torch.uint8)
Example output: tensor([10,  0,  0,  0,  0,  0,  0,  0,  0,  0], dtype=torch.uint8)
import numpy as np

# Quantize: q = clamp(round(x / scale) + zero_point), clipped to the target integer range
def quantize(x, qparams, itype):
    xtype = torch.iinfo(itype)
    return torch.clamp(torch.round(x / qparams[0]) + qparams[1], min=xtype.min, max=xtype.max)

# Dequantize: x ≈ (q - zero_point) * scale
def dequantize(x, qparams):
    return (x - qparams[1]) * qparams[0]

# im2col: unfold convolution patches into rows so Conv2d becomes a matrix multiplication
def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    N, C, H, W = input_data.shape
    out_h = (H + 2 * pad - filter_h) // stride + 1
    out_w = (W + 2 * pad - filter_w) // stride + 1

    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    for y in range(filter_h):
        y_max = y + stride * out_h
        for x in range(filter_w):
            x_max = x + stride * out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N * out_h * out_w, -1)
    return torch.tensor(col)

# First use im2col, an efficient way to perform the Conv2d operation
inp = im2col(img, 7, 7, 2, 3).float()
# Quantize input values using the input scale and zero point
inp = quantize(inp, [layer_input[0].q_scale(), layer_input[0].q_zero_point()], torch.uint8)
# Get the quantized weights and weight scale, and quantize the biases
w = qmodel._modules['0']._modules['0'].weight().int_repr().reshape(64, -1).float()
sw = qmodel._modules['0']._modules['0'].weight().q_scale()
b = quantize(qmodel._modules['0']._modules['0'].bias(),
             [layer_input[0].q_scale() * sw, 0], torch.int32)
b = b.reshape(1,64,1,1).detach()
# Calculate the matmul in Conv2d and add the biases
out = (w @ (inp.T - layer_input[0].q_zero_point())).view(1,64,112,112) + b
# Dequantize, apply ReLU, and requantize using the output scale and zero point
out = out * sw * layer_input[0].q_scale()
out = torch.relu(out)
out = quantize(out, [layer_output.q_scale(), layer_output.q_zero_point()], torch.uint8)
# Verify that the manual computation matches the output captured by the hook
assert torch.allclose(out, layer_output.int_repr().float())
print("Output: ", out[0, 0, 0, :10])
Output:  tensor([10.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])

Our calculation matches the actual result, which is a good sign. Although some operations in PyTorch’s implementation might be performed in a different order, the overall process is likely very similar.
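
As a follow-up, the dequantize helper defined above can map the integer output back to floating point using the layer's output scale and zero point; this should closely match what PyTorch's own dequantize() returns for the hooked output:

# Recover the floating-point activation from the manually computed integer output
out_fp = dequantize(out, [layer_output.q_scale(), layer_output.q_zero_point()])
print(torch.allclose(out_fp, layer_output.dequantize()))  # expected to print True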