
The Principle for Building Regular Equivariant CNNs

Published on 2024-07-31

The one principle can be stated simply: "Let the kernel rotate." In this article, we focus on how you can apply it in your own architectures.

Equivariant architectures let us train models that are indifferent to certain group actions, such as rotations of the input image.

To see what this means in practice, let us train the following simple CNN on MNIST, a dataset of handwritten digits from 0 to 9.

import torch
from torch import nn


class SimpleCNN(nn.Module):
    """Plain CNN for 1x28x28 MNIST images."""

    def __init__(self):
        super().__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)   # -> 8x28x28
        self.max_1 = nn.MaxPool2d(kernel_size=2)                                         # -> 8x14x14
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)  # -> 16x14x14
        self.max_2 = nn.MaxPool2d(kernel_size=2)                                         # -> 16x7x7
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)            # -> 16x1x1
        self.dense = nn.Linear(in_features=16, out_features=10)                          # 10 class logits

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)  # flatten (batch, 16, 1, 1) -> (batch, 16)
        logits = self.dense(x)
        return logits

Accuracy on test | Accuracy on 90-degree rotated test
97.3%            | 15.1%

Table 1: Test accuracy of the SimpleCNN model

As expected, we get over 95% accuracy on the test set. But what happens if we rotate the images by 90 degrees? Without any countermeasures, accuracy drops to 15.1%, only slightly better than random guessing (10% for ten classes). This model would be useless for applications where inputs can appear in any orientation.

In contrast, let us train a similar equivariant architecture, with no more parameters than SimpleCNN, in which the group actions are exactly the 90-degree rotations.

Accuracy on test | Accuracy on 90-degree rotated test
96.5%            | 96.5%

Table 2: Test accuracy of the EqCNN model, which has no more parameters than the SimpleCNN model

The accuracy is the same on both test sets, and we did not even use data augmentation.

These models become even more impressive with 3D data, but we will stick with this example to explore the core idea.

If you want to try it out yourself, all code is available for free in both PyTorch and JAX in the Github-Repo, and you can start training with Docker or Podman using just two commands.

Have fun!

So What is Equivariance?

Equivariant architectures guarantee the stability of features under certain group actions. Groups are simple structures whose elements can be combined and reversed (inverted), and which contain an element that does nothing (the identity).

You can look up the formal definition on Wikipedia if you are interested.

For our purposes, think of the group of 90-degree rotations acting on square images. We can rotate an image by 90, 180, 270, or 360 degrees. To reverse these actions, we apply a 270, 180, 90, or 0-degree rotation, respectively. It is straightforward to check that these rotations can be combined (a 90-degree rotation followed by a 180-degree rotation is a 270-degree rotation), reversed, and include a do-nothing element (the 360-degree rotation). This group is denoted $C_4$. Figure 1 shows all four actions on an MNIST image.

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
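To make the group structure concrete, here is a minimal pure-Python sketch. The helper names `rot90`, `compose`, and `inverse` are our own, chosen for illustration. It represents a rotation by k · 90 degrees with k in {0, 1, 2, 3}, so combining two rotations is addition modulo 4, the inverse of k is (4 - k) mod 4, and k = 0 is the do-nothing element. It then checks these rules against actual rotations of a small grid.

```python
def rot90(img, k=1):
    """Rotate a square 2D list counterclockwise by k * 90 degrees (same direction as numpy.rot90)."""
    for _ in range(k % 4):
        img = [list(row) for row in zip(*img)][::-1]
    return img


def compose(k1, k2):
    """Rotating by k1 * 90 degrees and then by k2 * 90 degrees."""
    return (k1 + k2) % 4


def inverse(k):
    """The rotation that undoes a rotation by k * 90 degrees."""
    return (4 - k) % 4


IDENTITY = 0  # the 0-degree (= 360-degree) rotation does nothing

img = [[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]]

for k1 in range(4):
    # inverses: rotating by k1 and then by inverse(k1) restores the image
    assert rot90(rot90(img, k1), inverse(k1)) == img
    for k2 in range(4):
        # composition: two rotations in a row equal one rotation by the summed angle
        assert rot90(rot90(img, k1), k2) == rot90(img, compose(k1, k2))

assert rot90(img, IDENTITY) == img
print(rot90(img, 1))  # [[3, 6, 9], [2, 5, 8], [1, 4, 7]]
```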

Now, given an input image $x$, our CNN classifier $f_\theta$, and an arbitrary 90-degree rotation $g$, the property we want can be expressed as

$$f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x).$$

Strictly speaking, this is invariance: the special case of equivariance, $f_\theta(g \cdot x) = g \cdot f_\theta(x)$, in which $g$ leaves the output unchanged.

In other words, we want our image classifier to produce the same output when its input is rotated.
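One simple way to test this property for any function is to evaluate it on all four rotations of an input and compare the results. The pure-Python sketch below does this for two toy "models" on a 2D list: reading the top-left pixel, which depends on orientation, and taking the brightest pixel, which does not. The helper names are hypothetical and only serve as an illustration.

```python
def rot90(img, k=1):
    """Rotate a square 2D list counterclockwise by k * 90 degrees (same direction as numpy.rot90)."""
    for _ in range(k % 4):
        img = [list(row) for row in zip(*img)][::-1]
    return img


def is_c4_invariant(f, img):
    """True if f returns the same output for all four 90-degree rotations of img."""
    reference = f(img)
    return all(f(rot90(img, k)) == reference for k in range(1, 4))


def top_left_pixel(img):
    return img[0][0]  # depends on where things are, so not rotation-invariant


def brightest_pixel(img):
    return max(max(row) for row in img)  # ignores positions, so rotation-invariant


digit = [[0, 9, 0],
         [0, 9, 0],
         [0, 9, 5]]

print(is_c4_invariant(top_left_pixel, digit))   # False
print(is_c4_invariant(brightest_pixel, digit))  # True
```

A plain CNN such as SimpleCNN behaves like the first function in this respect: nothing in its architecture forces its logits to agree across rotations.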

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: when the image rotates, the features rotate too. But, as already hinted, we could compute the features for every rotation up front by rotating the kernel. We could literally rotate the kernel, but it is much easier to rotate the input image itself instead. The trainable weights then stay untouched, which avoids any interference with PyTorch's automatic differentiation.
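Why is rotating the input as good as rotating the kernel? For a cross-correlation (which is what nn.Conv2d computes) with a square kernel, correlating a rotated image with a kernel $w$ gives the same result as correlating the original image with the kernel rotated the opposite way and then rotating the output: $\mathrm{corr}(R^k x, w) = R^k\,\mathrm{corr}(x, R^{-k} w)$, where $R$ is a 90-degree rotation. The pure-Python sketch below checks this identity on a random image with "valid" (unpadded) correlation; `corr2d` is our own helper, not a PyTorch function.

```python
import random


def rot90(img, k=1):
    """Rotate a square 2D list counterclockwise by k * 90 degrees (same direction as numpy.rot90)."""
    for _ in range(k % 4):
        img = [list(row) for row in zip(*img)][::-1]
    return img


def corr2d(img, kernel):
    """'Valid' 2D cross-correlation of square lists, like an unpadded nn.Conv2d with one channel."""
    n, m = len(img), len(kernel)
    size = n - m + 1
    return [[sum(img[i + a][j + b] * kernel[a][b]
                 for a in range(m) for b in range(m))
             for j in range(size)]
            for i in range(size)]


random.seed(0)
x = [[random.randint(0, 9) for _ in range(6)] for _ in range(6)]
w = [[random.randint(-2, 2) for _ in range(3)] for _ in range(3)]

for k in range(4):
    # option A: rotate the image, keep the kernel fixed
    rotate_input = corr2d(rot90(x, k), w)
    # option B: keep the image, rotate the kernel the other way, then rotate the output
    rotate_kernel = rot90(corr2d(x, rot90(w, -k)), k)
    assert rotate_input == rotate_kernel
print("rotating the input matches rotating the kernel, up to a rotation of the output")
```

So applying one fixed kernel to the four rotated images yields the responses of the four rotated kernels, each expressed in the rotated frame.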

So, in code, our first convolutional layer

x = nn.functional.silu(self.cl1(x))

now acts on all four rotated images:

# x: (batch, 1, 28, 28); rotate over the two spatial dims (H, W)
x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

# the same layer, i.e. the same kernel weights, is applied to every rotated copy
x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))

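Or, more compactly, as a 3D convolution whose kernel has depth 1, so that the same 2D kernel is applied to each of the four rotated copies independently:

# in __init__: depth-1 kernel shared across the rotation axis (no padding, unlike SimpleCNN's cl1)
self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
# in forward: stack the rotations along a new axis -> (batch, 1, 4, 28, 28)
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))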
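It helps to see what happens to these four copies when the input itself is rotated by 90 degrees: the list [x_0, x_90, x_180, x_270] built from the rotated input is the original list shifted cyclically by one position, because a 90-degree rotation followed by a 270-degree rotation is a full turn. The pure-Python sketch below checks this on a small grid; `rot90` and `orbit` are our own helpers standing in for torch.rot90 and torch.stack.

```python
def rot90(img, k=1):
    """Rotate a square 2D list counterclockwise by k * 90 degrees (same direction as numpy.rot90)."""
    for _ in range(k % 4):
        img = [list(row) for row in zip(*img)][::-1]
    return img


def orbit(img):
    """The four rotated copies [x_0, x_90, x_180, x_270] that the model stacks."""
    return [rot90(img, k) for k in range(4)]


x = [[1, 2],
     [3, 4]]

copies = orbit(x)
copies_of_rotated = orbit(rot90(x))

# rotating the input only moves every copy one slot to the left, wrapping around
assert copies_of_rotated == copies[1:] + copies[:1]
print(copies_of_rotated[3] == x)  # True: the last slot wraps around to the unrotated image
```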

The resulting equivariant model needs only a few more lines than SimpleCNN:

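import torch
from torch import nn


class EqCNN(nn.Module):
    """Applies the same 2D kernels to all four 90-degree rotations of the input, then max-pools over them."""

    def __init__(self):
        super().__init__()
        # depth-1 kernels: each rotated copy is processed independently with shared weights
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))    # 28x28 -> 26x26 (no padding)
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))                               # -> 13x13
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))   # -> 11x11
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))                               # -> 5x5
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))  # -> 1x1
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, 1, 28, 28); build the four rotated copies
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        # stack along a new rotation axis: (batch, 1, 4, 28, 28)
        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))  # (batch, 16, 4, 1, 1)

        # drop only the spatial dims; a bare squeeze() would also drop a batch dim of size 1
        x = x.squeeze(-1).squeeze(-1)        # (batch, 16, 4)
        x = torch.max(x, dim=-1).values      # max over the four rotations -> (batch, 16)
        logits = self.dense(x)
        return logits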
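As a sanity check on model size, the parameter counts of both models as written can be worked out by hand: a convolution with c_in input channels, c_out output channels, and a kh x kw kernel has c_out * c_in * kh * kw weights plus c_out biases, and a Conv3d with kernel depth 1 counts the same as its 2D counterpart. The helper names below are our own.

```python
def conv_params(c_in, c_out, kh, kw):
    """Weights plus biases of a Conv2d, or of a Conv3d with kernel depth 1."""
    return c_out * c_in * kh * kw + c_out


def linear_params(n_in, n_out):
    """Weights plus biases of an nn.Linear layer."""
    return n_out * n_in + n_out


simple_cnn = (conv_params(1, 8, 3, 3) + conv_params(8, 16, 3, 3)
              + conv_params(16, 16, 7, 7) + linear_params(16, 10))
eq_cnn = (conv_params(1, 8, 3, 3) + conv_params(8, 16, 3, 3)
          + conv_params(16, 16, 5, 5) + linear_params(16, 10))

print(simple_cnn, eq_cnn)  # 13978 7834
```

The difference comes entirely from cl3 (a 7x7 kernel in SimpleCNN versus 5x5 in EqCNN, whose convolutions use no padding), so EqCNN as written is the smaller model. In PyTorch, such counts can be confirmed with `sum(p.numel() for p in model.parameters())`.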

But why is this equivariant to rotations?
First, observe that every stage produces four copies of each feature map, one per rotation. At the end of the pipeline, we combine them with a max operation over the four copies.

This is the key: the max operation does not care in which of the four positions a rotated version of a feature ends up.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated

I color-coded the corresponding maps: each feature map has moved over by one position among the four rotations. Since the final max operation computes the same result for these shifted feature maps, the model's output is unchanged.
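This argument can be checked numerically without any training. The pure-Python sketch below replaces the network with a hypothetical toy branch (a "valid" cross-correlation, a ReLU, and a global sum) that is applied with shared weights to every rotated copy, mimicking the depth-1 Conv3d layers, and then takes the max over the four copies. Rotating the input shifts the per-rotation features by one slot, and the max stays the same.

```python
import random


def rot90(img, k=1):
    """Rotate a square 2D list counterclockwise by k * 90 degrees (same direction as numpy.rot90)."""
    for _ in range(k % 4):
        img = [list(row) for row in zip(*img)][::-1]
    return img


def corr2d(img, kernel):
    """'Valid' 2D cross-correlation of square lists."""
    n, m = len(img), len(kernel)
    size = n - m + 1
    return [[sum(img[i + a][j + b] * kernel[a][b] for a in range(m) for b in range(m))
             for j in range(size)] for i in range(size)]


def toy_features(img, kernel):
    """Stand-in for one branch of the network: conv -> ReLU -> global sum (one scalar)."""
    return sum(max(v, 0) for row in corr2d(img, kernel) for v in row)


def toy_eq_model(img, kernel):
    """Apply the same branch to all four rotations, then take the max over them."""
    per_rotation = [toy_features(rot90(img, k), kernel) for k in range(4)]
    return per_rotation, max(per_rotation)


random.seed(1)
x = [[random.randint(0, 9) for _ in range(5)] for _ in range(5)]
w = [[random.randint(-3, 3) for _ in range(3)] for _ in range(3)]

feats, out = toy_eq_model(x, w)
feats_rot, out_rot = toy_eq_model(rot90(x), w)

assert feats_rot == feats[1:] + feats[:1]  # each feature moved over by one slot
assert out_rot == out                      # the max ignores the shift
print(feats, feats_rot, out)
```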

In my code, I did not rotate the feature maps back after the final convolution, since by then my kernels have condensed each image to a single value per channel, where orientation no longer matters. If you want to extend this example to outputs that keep spatial structure, you need to account for this.
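If you keep spatial feature maps until the end (for example, for a hypothetical per-pixel output), one option is to rotate each branch's output back by the inverse rotation before taking a pixelwise max. The pure-Python sketch below, again with our own helper names, checks that the resulting map then rotates together with the input. This is equivariance in the usual sense rather than invariance.

```python
import random


def rot90(img, k=1):
    """Rotate a square 2D list counterclockwise by k * 90 degrees (same direction as numpy.rot90)."""
    for _ in range(k % 4):
        img = [list(row) for row in zip(*img)][::-1]
    return img


def corr2d_relu(img, kernel):
    """'Valid' 2D cross-correlation followed by a ReLU."""
    n, m = len(img), len(kernel)
    size = n - m + 1
    return [[max(0, sum(img[i + a][j + b] * kernel[a][b]
                        for a in range(m) for b in range(m)))
             for j in range(size)] for i in range(size)]


def equivariant_map(img, kernel):
    """Run the shared branch on every rotation, rotate each result back, then take a pixelwise max."""
    branches = [rot90(corr2d_relu(rot90(img, k), kernel), -k) for k in range(4)]
    size = len(branches[0])
    return [[max(b[i][j] for b in branches) for j in range(size)] for i in range(size)]


random.seed(2)
x = [[random.randint(0, 9) for _ in range(6)] for _ in range(6)]
w = [[random.randint(-3, 3) for _ in range(3)] for _ in range(3)]

# rotating the input rotates the output map in the same way
for k in range(4):
    assert equivariant_map(rot90(x, k), w) == rot90(equivariant_map(x, w), k)
print("output map rotates with the input")
```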

Accounting for group actions or "kernel rotations" plays a vital role in the design of more sophisticated architectures.

Is it a Free Lunch?

No, we pay in computational speed, inductive bias, and a more complex implementation.

The last point is partly addressed by libraries such as E3NN, which abstract away most of the heavy math. Nevertheless, there is still a lot to account for during architecture design.

One superficial weakness is the 4x computational cost of computing the feature maps for all four rotations. However, modern, massively parallel hardware can largely absorb this load. In contrast, training a simple CNN with data augmentation would easily take more than 10x the training time. This gets even worse for 3D rotations, where data augmentation would require about 500x as much training to cover all possible rotations.

Overall, equivariant model design is, more often than not, a price worth paying if you want stable features.

What is Next?

Equivariant model designs have exploded in recent years, and in this article we have barely scratched the surface. In fact, we did not even exploit the full $C_4$ group yet: we could have used full 3D kernels that also combine information across the four rotated copies. However, our model already achieves over 95% accuracy, so there is little reason to go further with this example.

Besides CNNs, researchers have successfully translated these principles to continuous groups, including $SO(2)$ (the group of all rotations in the plane) and $SE(3)$ (the group of all translations and rotations in 3D space).

In my experience, these models are absolutely mind-blowing: trained from scratch, they can perform comparably to foundation models trained on datasets many times larger.

Let me know if you want me to write more on this topic.

Further References

If you want a formal introduction to this topic, here is an excellent compilation of papers covering the complete history of equivariance in machine learning:
AEN

I actually plan to create a deep-dive, hands-on tutorial on this topic. You can already sign up for my mailing list, and I will provide you with free versions over time, along with a direct channel for feedback and Q&A.

See you around :)

Source: this article is reprinted from https://dev.to/freiberg-roman/the-1-principle-to-build-regular-equivariant-cnns-338b?1