CoAtNet(NeurIPS 2023, Google)论文解读

paper:CoAtNet: Marrying Convolution and Attention for All Data Sizes

third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/maxxvit.py

背景

自AlexNet以来,ConvNets一直是计算机视觉领域的主流模型。 Transformers在自然语言处理取得成功后,许多研究尝试将其引入计算机视觉领域。尽管Vision Transformer (ViT) 取得了一些成果,但其在小数据集上的表现仍不如ConvNets。

作者认为,Transformers可能缺乏卷积网络所拥有的某些理想的归纳偏差(inductive bias),这导致它们需要大量的数据和计算资源来补偿。因此本文主要讨论了如何将卷积神经网络(ConvNets)和自注意力机制(Transformers)结合在一起,以实现更好的图像分类性能。

创新点

该研究旨在解决以下问题:

  1. 如何在一个基本计算模块内结合卷积和自注意力机制。
  2. 如何垂直堆叠不同类型的计算模块,形成一个完整的网络。

创新点包括:

  1. 提出了深度卷积(depthwise Convolution)和自注意力(self-Attention)可以通过简单的相对注意力(relative attention)实现统一。
  2. 通过合理的方式垂直堆叠卷积层和注意力层,可以显著提高模型的泛化能力和容量。
  3. 提出了 CoAtNet 架构,它结合了 ConvNets 和 Transformers 的优点。

效果

  • 未使用额外数据时,CoAtNet达到了86.0%的ImageNet top-1准确率。
  • 在ImageNet-21K数据集(1300万张图像)上进行预训练后,CoAtNet达到了88.56%的top-1准确率,与使用300M张图像进行预训练的ViT-Huge相当,但数据量减少了23倍。
  • 在JFT-3B数据集上进行预训练后,CoAtNet达到了90.88%的top-1准确率,创下了新的记录。

方法介绍

Merging Convolution and Self-Attention

对于卷积作者主要关注MBConv block,它使用深度卷积来捕获空间相互作用。选择它的原始是Transforme中的FFN和MBConv block一样都采用了"inverted bottleneck"的设计。

深度卷积和self-attention都可以表示为一个在预先定义的感受野内进行每个维度值的加权求和过程。具体来说,卷积依赖一个固定的kernel从一个局部感受野内收集信息

$$ y_i=\sum_{j \in \mathcal{L}(i)} w_{i-j} \odot x_j \quad \text { (depthwise convolution), }  \qquad \tag1 $$ 

其中 \(x_i,y_x\in \mathbb{R}^D\) 分别是位置 \(i\) 处的输入和输出,\(\mathcal{L}(i)\) 表示 \(i\) 的一个局部邻域,比如中心点为 \(i\) 的一个3x3方格。

相比之下self-attention的感受野为全部空间位置

$$ y_i=\sum_{j \in \mathcal{G}} \underbrace{\frac{\exp \left(x_i^{\top} x_j\right)}{\sum_{k \in \mathcal{G}} \exp \left(x_i^{\top} x_k\right)}}_{A_{i, j}} x_j \quad \text{ (self-attention)}, \qquad \tag2 $$

其中 \(\mathcal{G}\) 表示全局位置空间。

在讨论如何更好地组合它们之前,我们先比较一下它们的相对优势和劣势,这有助于找到我们希望保留的特性。

  • 首先深度卷积核 \(w_{i-j}\) 是一个不依赖于输入的静态参数,而attention权重 \(A_{i,j}\) 则动态地依赖输入表示。因此self-attention更容易捕获不同空间位置之间复杂的交互关系,但这种灵活性也更容易过拟合,特别是在数据有限的情况下。
  • 其次给定任意一对位置 \((i,j)\),对应的卷积权重 \(w_{i-j}\) 只关心它们之间的相对位移即 \(i-j\) 而不关心 \(i,j\) 的具体值。这就是我们常说的平移不变性,这一特性可以提高有限数据下模型的泛化性。由于使用了绝对位置编码ViT缺乏这一特性,这也解释了为什么在数据量有限时ConvNets的效果比Transformers要好。

  • 相比于卷积的局部感受野,Transformer具有全局感受野,更大的感受野提供了更多的上下文信息,提高了模型的容量,同时也需要更多的计算量。

根据上面的分析,一个理想的模型应该同时具备表1中的三点特性。根据式(1)和式(2),一个直接的想法是将一个全局静态卷积核和一个动态注意力矩阵相加,在softmax函数之前或之后都可以,如下

$$y_i^{\text {post }}=\sum_{j \in \mathcal{G}}\left(\frac{\exp \left(x_i^{\top} x_j\right)}{\sum_{k \in \mathcal{G}} \exp \left(x_i^{\top} x_k\right)}+w_{i-j}\right) x_j \quad or \quad y_i^{\mathrm{pre}}=\sum_{j \in \mathcal{G}} \frac{\exp \left(x_i^{\top} x_j+w_{i-j}\right)}{\sum_{k \in \mathcal{G}} \exp \left(x_i^{\top} x_k+w_{i-k}\right)} x_j \tag3$$

本文采用 \(y_i^{pre}\),其实这就是一种relative self-attention,和swin-transformer中的relative position bias是一模一样。

Vertical Layout Design

在找到了如何将卷积和注意力结合起来后,作者研究了如何stack layers来构建网络。

由于attention的计算量和输入分辨率是二次方关系,直接将式(3)应用于原始输入会导致计算量过大。因此作者采用了对原始输入进行降采样,在分辨率到了一个可控水平后再应用global relative attention的方式。对于降采样有两种方式,一是像原始的ViT那样直接采用一个大步长stride=16的卷积进行降采样,二是像ConvNets那样多个stage逐步降采样。

作者首先给出了5种设计选项,然后通过实验来比较效果。第一种是应用ViT stem,然后堆叠 \(L\) 个relative attention的Transformer block,将其表示为 \(\mathbf{ViT}_{REL}\)。

然后是multi-stage的方式,一共包含5个stage(S0, S1, S2, S3, S4),在每个stage的开始进行2x的降采样。S0是一个2层卷积的stem,S1是一个带有SE的MBConv block,前两个stage的设计是固定的。然后S2到S4我们使用MBConv或Transformer block,其中保证卷积stage一定在Transformer stage前面,这样我们就得到了四种不同的变体,C-C-C-C、C-C-C-T、C-C-T-T、C-T-T-T,其中C和T分别表示卷积和Transformer。

作者在ImageNet-1K和JFT数据集上进行了模型容量和泛化性能的比较,结果如图1所示

泛化性能表现

$$\mathrm{C}-\mathrm{C}-\mathrm{C}-\mathrm{C} \approx \mathrm{C}-\mathrm{C}-\mathrm{C}-\mathrm{T} \geq \mathrm{C}-\mathrm{C}-\mathrm{T}-\mathrm{T}>\mathrm{C}-\mathrm{T}-\mathrm{T}-\mathrm{T} \gg \mathrm{VIT}_{\mathrm{REL}}$$

模型容量对比

$$\mathrm{C}-\mathrm{C}-\mathrm{T}-\mathrm{T} \approx \mathrm{C}-\mathrm{T}-\mathrm{T}-\mathrm{T}>\mathrm{VIT}_{\mathrm{REL}}>\mathrm{C}-\mathrm{C}-\mathrm{C}-\mathrm{T}>\mathrm{C}-\mathrm{C}-\mathrm{C}-\mathrm{C}$$

最后在C-C-T-T和C-T-T-T之间选择,作者又做了一个迁移性测试,两个在JFT上预训练的模型在ImageNet-1K上微调30个epoch,然后比较它们的性能,结果如表2所示

可以看到C-C-T-T的迁移性明显更好,最终CoAtNet选择了C-C-T-T multi-stage的设计。

实验结果

不同大小的CoAtNet的配置如下

在ImageNet-1k上和其它模型的对比如下

代码解析

代码没什么好讲的,其中式(3)的 \(y_i^{pre}\) 就是swin transformer中的relative position bias,代码都是一样的,可以参考Swin Transformer(ICCV 2021)论文与代码解析-CSDN博客。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/769073.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【接口自动化测试】第四节.实现项目核心业务的单接口自动化测试

文章目录 前言一、登录单接口自动化测试 1.1 登录单接口文档信息 1.2 登录成功 1.3 登录失败(用户名为空)二、数据驱动的实现 2.1 json文件实现数据驱动三、课程添加单接口自动化测试 3.1 课程添加单接口文档信息 3.2 课程…

N5 使用Gensim库训练Word2Vec模型

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊# 前言 前言 这周学习训练一个Word2Vec模型,并进行一些基本的词向量操作。 Word2Vec 模型 Word2Vec 是一种基于神经网络的词向量表示方法&#x…

Qt Q_ASSERT详解

Q_ASSERT详解 引言一、基本用法二、深入了解三、参考链接 引言 Q_ASSERT是 Qt 框架中的一个宏,用于在调试时检查某个条件是否为真。它是程序调试中的一个重要工具,有助于开发者在开发过程中及时发现并修复潜在的错误。 一、基本用法 只在使用 Qt 的 D…

API 授权最佳实践

API(应用程序编程接口)就像秘密之门,允许不同的软件程序进行通信。但并不是每个人都应该拥有每扇门的钥匙,就像不是每个软件都应该不受限制地访问每个 API 一样。 这些 API 将从银行的移动应用程序到您最喜欢的社交媒体平台的所有…

嵌入式C语言中指针与链表的关系详解

假定给你一块非常小的内存,这块内存只有8字节,这里也没有高级语言,没有操作系统,你操作的数据单位是单个字节,你该怎样读写这块内存呢? 注意这里的限定,再读一遍,没有高级语言,没有操作系统,在这样的限制之下,你必须直面内存读写的本质。 这个本质是什么呢? 本质…

Vuex的基本使用

1.安装vuex npm i vuex3 2.引入 import Vuex from vuex 3.使用 Vue.use(Vuex) 4.在src下的目录创建store,新建index.js import store from ./store 5.编写index.js import Vue from vue import Vuex from vuex Vue.use(Vuex)//用于操作组件中的动作 const actions{a…

Linux安装Node-RED并实现后台运行及开机启动

首先确保系统中已近成功安装Node.js,并保证需要的合适版本: 关于node.js的安装可以参考我的另一篇博文:《AliyunOS安装Node.js》。 然后就可以使用npm工具安装Node-RED了,很简单使用如下命令: sudo npm install -g --unsafe-per…

antd Select前端加模糊搜索

背景&#xff1a;前端的小伙伴经常在开发antd Select的时候后端不提供搜索模糊搜索接口&#xff0c;而是全量返回数据&#xff0c;这个时候就需要我们前端自己来写一个模糊搜索了。 效果 代码截图 代码 <SelectshowSearchmode"multiple"options{studioList}filte…

视频分析、目标检测的过去和未来:目标检测从入门到精通 ------ YOLOv8 到 多模态大模型处理视觉基础任务

文章大纲 计算机视觉项目的关键步骤目标检测入门视频分析项目最佳实践数据集构建数据准备:数据集标注规范与数据规模参考标注工具标注工具:目标检测yolo 极简标注工具综合标注工具:label-studio半自动标注工具:X-AnyLabeling目标检测与多模态哪些多模态模型可以做目标检测?…

构建安全稳定的应用:Spring Security 实用指南

前言 在现代 Web 应用程序中&#xff0c;安全性是至关重要的一个方面。Spring Security 作为一个功能强大且广泛使用的安全框架&#xff0c;为 Java 应用程序提供了全面的安全解决方案。本文将深入介绍 Spring Security 的基本概念、核心功能以及如何在应用程序中使用它来实现…

招聘应聘,HR如何测试候选人的领导能力?

作为企业的HR&#xff0c; 如何通过测评的方式来了解一个人的领导能力&#xff1f; 这里仅仅是说测评的方式&#xff0c;除此以外&#xff0c;还有很多方式&#xff0c;比如&#xff1a;背景调查&#xff0c;无领导小组讨论等等..... 对于一个人的领导能力测试&#xff0c;主要…

网页报错dns_probe_possible 怎么办?——错误代码有效修复

当你在浏览网页时遇到dns_probe_possible 错误&#xff0c;这通常意味着你的浏览器无法解析域名系统&#xff08;DNS&#xff09;地址。这个问题可能是由多种原因引起的&#xff0c;包括网络配置问题、DNS服务问题、或是本地设备的问题。教大家几种修复网页报错dns_probe_possi…

ctfshow-xss(web316-web330)

讲解相当细致 精致练习XSS web316 这道题估计陆陆续续弄了半天 因为xss可以说基本不会 还好最终彻彻底底明白了 首先这道题是反射性xss 也就是必须点击某一个xss链接 才能达到xss效果 这道题的意思就是 写一个祝福语生成链接发送给朋友 这个祝福语的位置就是我们实现XSS的位…

GPT-4预测股票涨跌更更更准了!东京大学新框架LLMFactor提升显著 | ACL 2024

花一秒钟就看透事物本质的人&#xff0c;和花一辈子都看不清的人&#xff0c;注定是截然不同的命运。——唐柯里昂 除了少数天纵奇才&#xff0c;大多数人都是通过知识和阅历的不断积累&#xff0c;才逐渐锻炼出观察和判断事物变化规律的能力。而如果说有一件事&#xff0c;可以…

代码便利工具

【原创】PyCharm 安装MarkDown插件&#xff0c;并修改.md文件默认打开方式_pycharm如何修改markdown-CSDN博客 1.上面是填写README的工具。

DeepFaceLive----AI换脸简单使用

非常强大的软件,官方github https://github.com/iperov/DeepFaceLive 百度云链接: 链接&#xff1a;https://pan.baidu.com/s/1VHY-wxqJXSh5lCn1c4whZg 提取码&#xff1a;nhev 1下载解压软件 下载完成后双击.exe文件进行解压.完成后双击.bat文件打开软件 2 视频使用图片换…

JAVA+SSM+VUE《病人跟踪治疗信息管理系统》

1病人功能模块 病人登录进入病人跟踪治疗信息管理系统可以查看首页、个人中心、病例采集管理、预约管理、医生管理、上传核酸检测报告管理、上传行动轨迹管理、病人治疗状况管理等内容。 病例采集管理&#xff0c;在病例采集管理页面可以查看账号、姓名、住院号、入院时间、病…

2024鲲鹏昇腾创新大赛集训营Ascend C算子学习笔记

异构计算架构&#xff08;CANN&#xff09; 对标英伟达的CUDA CuDNN的核心软件层&#xff0c;向上支持多种AI框架&#xff0c;向下服务AI处理器&#xff0c;发挥承上启下的关键作用&#xff0c;是提升昇腾AI处理器计算效率的关键平台。主要包括有各种引擎、编译器、执行器、算…

[leetcode hot 150]第三题,无重复字符的最长子串

题目&#xff1a; 给定一个字符串 s &#xff0c;请你找出其中不含有重复字符的 最长 子串的长度。 可以使用"滑动窗口"的方法来解决这个问题。基本思路如下: 使用两个指针(start和end)来定义一个窗口移动end指针来扩大窗口,直到遇到重复字符如果遇到重复字符,移动s…

Spring源码九:BeanFactoryPostProcessor

上一篇Spring源码八&#xff1a;容器扩展一&#xff0c;我们看到ApplicationContext容器通过refresh方法中的prepareBeanFactory方法对BeanFactory扩展的一些功能点&#xff0c;包括对SPEL语句的支持、添加属性编辑器的注册器扩展解决Bean属性只能定义基础变量的问题、以及一些…