
The One Principle to Build Regular Equivariant CNNs

Published on 2024-07-31

The one principle is simply stated: 'Let the kernel rotate'. In this article, we will focus on how you can apply it in your architectures.

Equivariant architectures allow us to train models which are indifferent to certain group actions.

To understand what this exactly means, let us train this simple CNN model on the MNIST dataset (a dataset of handwritten digits from 0-9).

import torch
from torch import nn


class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)  # collapses the 7x7 map to 1x1
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)  # flatten to (batch, 16)
        logits = self.dense(x)
        return logits
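For completeness, here is a minimal training sketch. It assumes torchvision is available for downloading MNIST; the batch size, learning rate, and number of epochs are illustrative choices, not settings taken from this article.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

train_data = datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)

model = SimpleCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

model.train()
for epoch in range(3):  # a few epochs typically suffice for high test accuracy on MNIST
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()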
Accuracy on test: 97.3%
Accuracy on 90-degree rotated test: 15.1%

Table 1: Test accuracy of the SimpleCNN model

As expected, we get over 95% accuracy on the testing dataset, but what if we rotate the image by 90 degrees? Without any countermeasures applied, the results drop to just slightly better than guessing. This model would be useless for general applications.
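A minimal evaluation sketch for this comparison (assuming a test_loader built analogously to the training loader above, and the model trained as sketched earlier; the rotation is applied with torch.rot90 over the spatial dimensions):

def accuracy(model, loader, rotate=False):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            if rotate:
                images = torch.rot90(images, k=1, dims=(2, 3))  # rotate the batch by 90 degrees
            preds = model(images).argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += len(labels)
    return correct / total

print(accuracy(model, test_loader))               # plain test set
print(accuracy(model, test_loader, rotate=True))  # 90-degree rotated test set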

In contrast, let us train a similar equivariant architecture with the same number of parameters, where the group actions are exactly the 90-degree rotations.

Accuracy on test: 96.5%
Accuracy on 90-degree rotated test: 96.5%

Table 2: Test accuracy of the EqCNN model with the same number of parameters as the SimpleCNN model

The accuracy remains the same, and we did not even opt for data augmentation.

These models become even more impressive with 3D data, but we will stick with this example to explore the core idea.

In case you want to test it out for yourself, all the code, written in both PyTorch and JAX, is freely available in the Github-Repo, and training with Docker or Podman takes just two commands.

Have fun!

So What is Equivariance?

Equivariant architectures guarantee that features remain stable under certain group actions. Groups are simple structures whose elements can be combined, inverted, or applied as the identity (do nothing).

You can look up the formal definition on Wikipedia if you are interested.

For our purposes, you can think of a group of 90-degree rotations acting on square images. We can rotate an image by 90, 180, 270, or 360 degrees. To reverse the action, we apply a 270, 180, 90, or 0-degree rotation respectively. It is straightforward to see that we can combine, reverse, or do nothing with rotations, so they form the group denoted $C_4$. Figure 1 visualizes all group actions on an image.

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
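These group axioms can be checked directly with torch.rot90. A small sketch on an arbitrary image tensor, showing that composing rotations stays in the group and that every rotation has an inverse:

import torch

img = torch.randn(1, 1, 28, 28)  # any square image

r90 = lambda t: torch.rot90(t, k=1, dims=(2, 3))
r180 = lambda t: torch.rot90(t, k=2, dims=(2, 3))
r270 = lambda t: torch.rot90(t, k=3, dims=(2, 3))

# combining two group elements stays in the group: 90° followed by 90° is 180°
print(torch.equal(r90(r90(img)), r180(img)))  # True
# every element has an inverse: 90° followed by 270° is the identity
print(torch.equal(r270(r90(img)), img))       # True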

Now, given an input image $x$, our CNN model classifier $f_\theta$, and an arbitrary 90-degree rotation $g$, the equivariant property can be expressed as

$$f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x)$$

Generally speaking, we want our image-based model to produce the same output when its input is rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: when the image rotates, the features rotate too. But as already hinted, we can compute the features for every rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map (at the first layer, the input image itself) instead, which avoids any interference with PyTorch's autodifferentiation altogether.
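Both routes work because of a standard identity of (valid, unpadded) 2D cross-correlation: rotating the input and the kernel together simply rotates the output. A small numerical sketch of this with random tensors (not part of the article's code):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 28, 28)  # input image
w = torch.randn(8, 1, 3, 3)    # conv weights, shape (out_channels, in_channels, kH, kW)

rot = lambda t: torch.rot90(t, k=1, dims=(2, 3))

# rotating input and kernel together rotates the feature map
print(torch.allclose(F.conv2d(rot(x), rot(w)), rot(F.conv2d(x, w)), atol=1e-5))  # True

# hence, convolving x with a rotated kernel equals convolving a counter-rotated x
# with the original kernel and rotating the result back
lhs = F.conv2d(x, rot(w))
rhs = rot(F.conv2d(torch.rot90(x, k=-1, dims=(2, 3)), w))
print(torch.allclose(lhs, rhs, atol=1e-5))  # True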

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))

now acts on all four rotated images:

# the four rotated copies of the input (the C_4 orbit)
x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))
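As a quick sanity check on the shapes (again just a sketch): stacking the four rotated copies inserts a group dimension of size 4 in front of the spatial dimensions, and the (1, 3, 3) kernel then slides over each copy with shared weights.

import torch
from torch import nn

x = torch.randn(16, 1, 28, 28)         # (batch, channels, H, W)
x = torch.stack([x, x, x, x], dim=-3)  # (16, 1, 4, 28, 28): a group dimension of size 4
cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
print(cl1(x).shape)                    # torch.Size([16, 8, 4, 26, 26])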

The resulting equivariant model has only a few more lines than the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()  # (batch, 16, 4): channels and the four group copies
        x = torch.max(x, dim=-1).values  # max over the group dimension gives rotation invariance
        logits = self.dense(x)
        return logits
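With the model defined, the invariance promised by the formula above can be checked directly, even on untrained weights. A quick sketch (a batch size above one keeps the squeeze from removing the batch dimension):

model = EqCNN()
x = torch.randn(4, 1, 28, 28)  # a small batch of square images

logits = model(x)
logits_rot = model(torch.rot90(x, k=1, dims=(2, 3)))

print(torch.allclose(logits, logits_rot, atol=1e-5))  # True: identical outputs for rotated inputs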

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key: the max operation is indifferent to which position the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated

I color-coded the corresponding maps. Each feature map is cyclically shifted by one position along the group dimension. As the final max operator computes the same result for these shifted feature maps, we obtain the same output.
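In other words, the final reduction over the group dimension is invariant to the cyclic shift that an input rotation induces. A one-line check makes this explicit (a sketch with a random tensor standing in for the stacked feature maps):

import torch

feats = torch.randn(4, 16, 4)                   # (batch, channels, group) before the final max
shifted = torch.roll(feats, shifts=1, dims=-1)  # the cyclic shift caused by rotating the input
print(torch.equal(feats.max(dim=-1).values, shifted.max(dim=-1).values))  # True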

In my code, I did not rotate back after the final convolution, since my kernels condense the image to a one-dimensional array. If you want to expand on this example, you would need to account for this fact.

Accounting for group actions or "kernel rotations" plays a vital role in the design of more sophisticated architectures.

Is it a Free Lunch?

No, we pay in computational speed, inductive bias, and a more complex implementation.

The latter point is somewhat solved with libraries such as E3NN, where most of the heavy math is abstracted away. Nevertheless, one needs to account for a lot during architecture design.

One superficial weakness is the 4x computational cost for computing all rotated feature layers. However, modern hardware with mass parallelization can easily counteract this load. In contrast, training a simple CNN with data augmentation would easily exceed 10x in training time. This gets even worse for 3D rotations where data augmentation would require about 500x the training amount to compensate for all possible rotations.
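For comparison, the augmentation route for the $C_4$ case would look something like the helper below (a hypothetical sketch; the function name and its use as a per-batch transform are illustrative, not taken from the article):

import random
import torch

def random_c4_rotation(images: torch.Tensor) -> torch.Tensor:
    # rotate a whole batch by a random multiple of 90 degrees before each training step
    return torch.rot90(images, k=random.randint(0, 3), dims=(2, 3))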

Overall, equivariant model design is more often than not a price worth paying if one wants stable features.

What is Next?

Equivariant model designs have exploded in recent years, and in this article, we barely scratched the surface. In fact, we did not even exploit the full $C_4$ group yet. We could have used full 3D kernels. However, our model already achieves over 95% accuracy, so there is little reason to go further with this example.

Besides CNNs, researchers have successfully translated these principles to continuous groups, including $SO(2)$ (the group of all rotations in the plane) and $SE(3)$ (the group of all translations and rotations in 3D space).

In my experience, these models are absolutely mind-blowing: trained from scratch, they achieve performance comparable to that of foundation models trained on datasets many times larger.

Let me know if you want me to write more on this topic.

Further References

In case you want a formal introduction to this topic, here is an excellent compilation of papers covering the complete history of equivariance in Machine Learning: AEN

I actually plan to create a deep-dive, hands-on tutorial on this topic. You can already sign up for my mailing list, and I will provide you with free versions over time, along with a direct channel for feedback and Q&A.

See you around :)

Reposted from: https://dev.to/freiberg-roman/the-1-principle-to-build-regular-equivariant-cnns-338b