申请/专利权人:西南交通大学
申请日:2023-09-14
公开(公告)日:2024-04-09
公开(公告)号:CN117216566B
主分类号:G06F18/214
分类号:G06F18/214;G06N3/0895;G06N3/098
优先权:
专利状态码:有效-授权
法律状态:2024.04.09#授权;2023.12.29#实质审查的生效;2023.12.12#公开
摘要:本发明公开了一种基于局部‑全局伪标记的联邦半监督学习方法,步骤S1、在通信轮次t开始时,服务器将全局模型参数传输到活动客户端;每个客户端在接收到参数后再利用全局模型和上一轮通信轮次t‑1中训练得到的本地模型在未标记数据的弱增强视图上生成伪标签,并将其作为本地训练强增强视图的目标用于优化交叉熵损失;S2、本地训练结束后每个客户端将本地模型的参数发回到服务器,服务器聚合这些参数并对其进行微调,最后得到一个新的全局模型上述交替训练过程重复多次至全局模型收敛后结束。
主权项:1.一种基于局部-全局伪标记的联邦半监督学习方法,其特征在于,包括以下两个步骤:S1、在通信轮次t开始时,服务器将全局模型参数传输到活动客户端;每个客户端在接收到参数后再利用全局模型和上一轮通信轮次t-1中训练得到的本地模型在未标记数据的弱增强视图上生成伪标签,并将其作为本地训练强增强视图的目标用于优化交叉熵损失;其中,生成伪标签的方法如下:对于客户端Cu的未标记数据集使用公式1和2一次性标记数据集内所有数据,并通过下列公式3的方式构建一个固定的伪标记数据集 式中,I·是一个指示函数,DAw是弱数据增强操作,xu是未标记数据,是通信轮次t中全局模型的参数,F·是一个基于卷积神经网络的编码器,是置信度阈值;是上一轮通信轮次t-1中训练得到的本地模型参数,L是未标记数据集中的类别总数;如果数据集为空,则该客户端的训练过程将直接被跳过;否则,则从数据集Du中随机采样一个与等大小的数据集用于辅助训练,其定义如下: 其中,是生成的伪标签,是数据集的大小;在客户端的本地训练过程中,数据集和被随机划分为大小为Bu的小批次数据;从数据集和被划分的小批次数据中分别采样一对样本和然后利用线性插值方法构建一对新的样本: 其中,Beta·表示Beta分布,α是其对应的超参数;λ是Beta分布生成的一个数值,代表插值方法构造出的新数据,i表示索引下标,u是unlabeled缩写,代表未标记的数据;对于插值数据定义如下训练目标: 对于伪标记数据集定义如下训练目标: 其中,CE表示交叉熵损失,DAs是一种强数据增强操作;未标记客户端上总的训练目标如下所示: 其中λm是决定相对权重的超参数;S2、本地训练结束后,每个客户端将本地模型的参数发回到服务器,服务器首先聚合这些参数,然后利用标记数据集对其进行微调,最后得到一个新的全局模型上述交替训练过程重复多次至全局模型收敛后结束。
全文数据:
权利要求:
百度查询: 西南交通大学 一种基于局部-全局伪标记的联邦半监督学习方法
免责声明
1、本报告根据公开、合法渠道获得相关数据和信息,力求客观、公正,但并不保证数据的最终完整性和准确性。
2、报告中的分析和结论仅反映本公司于发布本报告当日的职业理解,仅供参考使用,不能作为本公司承担任何法律责任的依据或者凭证。