在训练大规模语言模型时,有意削弱其记忆能力有时反而能催生更智能的模型行为!
大语言模型若缺乏有效约束,极易完整复现训练数据中的内容,引发隐私与泛化问题。为此,来自马里兰大学、图宾根大学及马克斯·普朗克研究所的科研团队推出一种创新解决方案——金鱼损失(Goldfish Loss)。
该方法顾名思义,旨在让模型模仿金鱼的短暂记忆特性,避免机械式死记硬背每个细节,其核心是在损失函数计算过程中随机屏蔽一小部分token。
通过这一机制,模型不再逐字记忆训练集文本,转而专注学习语言的内在规律与模式。
实验数据显示,LLaMA-2模型应用金鱼损失后表现如下:
正如网友的生动比喻:这本质上是损失函数层面的dropout!
金鱼损失的根本原理相当直观,即在模型训练阶段随机剔除训练文本中的部分tokens,使其不参与损失值的计算。
这种设计使得模型在推理阶段遇到这些被屏蔽位置时,只能依据上下文进行合理推断,而非直接输出记忆中完整的训练序列。
此外,为确保被屏蔽token的一致性,研究团队开发了一套基于哈希(hashing)的掩码方案。
那么,金鱼损失与同样用于防止过拟合的正则化方法有何区别?
以Dropout这类经典正则化技术为例,它通过在训练过程中随机“注入噪声”来降低模型对特定神经元的依赖,从而提升泛化能力。
但这种方法存在局限:如果仅是随机丢弃token,模型在多次见到同一段落时,由于丢弃位置不同,可能通过累积学习拼凑出完整内容。
换言之,模型依然依赖于对训练数据的机械记忆。
相比之下,金鱼损失采用哈希掩码确保每次遇到相同段落时,被屏蔽的位置完全一致,这从源头切断了模型复现完整训练文本的可能性。
接下来,我们深入剖析金鱼损失的具体实现方式。
在传统的next-token prediction任务中,模型以序列中的下一个真实token作为目标,输出预测分布,并基于该分布计算交叉熵损失。
在金鱼损失框架下,模型在前向传播中依然预测序列的下一个token,但在计算损失时,会以特定概率将某些位置的token从损失计算中“移除”。
这意味着,部分真实的下一个token不会作为训练目标来优化模型。
研究中,团队首先使用了简单的静态掩码(static mask),例如固定剔除每个序列中的第4个token。
为进一步防止模型从其他上下文(如重复出现的网页文档)中学习到被掩码信息,他们又提出了局部化哈希掩码(localized hashed mask),使得当相同的前h个token出现时,掩码模式保持一致(可重复),从而强化了记忆抑制效果。
为证实金鱼损失能有效抑制记忆化,研究团队设置了两种实验环境:
一种是极端场景,通过对少量样本进行多轮次重复训练,强力诱导记忆行为;
另一种是标准场景,模拟实际训练中常见的批次处理流程。
同时,为量化模型的记忆程度,研究采用了以下评估指标:
RougeL得分:该指标通过计算最长公共子序列来反映记忆完整性,得分1.0代表完美复现。
精确匹配率(Exact Match):该指标衡量模型输出与原始序列完全一致的百分比。
实验结果显示,在极端场景下,标准训练导致模型逐字记忆了100篇文章中的84篇,而金鱼损失则成功避免了任何文章的完整记忆。
(注:实验基于LLaMA-2-7B模型,在《哈利·波特》第一章或100篇维基百科文档上额外训练了100个epoch)
此外,在标准训练场景下,金鱼损失也显著减少了模型对训练语料库中目标序列的逐字复现现象。
一个自然的疑问是:让模型“随机忽略”部分token是否会导致其能力下降?
针对此,研究人员进行了对比测试:结果表明,金鱼损失模型、标准损失模型以及对照模型在总体性能上并未出现系统性差异。
需要指出的是,金鱼损失的核心在于忽略部分token的梯度计算,因此模型需要通过接触更多数据来弥补这些信息空缺,这可能在一定程度上增加训练计算成本。
参考链接
[1]https://arxiv.org/pdf/2406.10209
本文由主机测评网于2025-12-28发表在主机测评网_免费VPS_免费云服务器_免费独立服务器,如有疑问,请联系我们。
本文链接:https://vpshk.cn/20251213441.html