书生大模型公式识别打榜赛参赛记录
kafm

参赛记录

本文记录参加书生大模型社区比赛的一些过程。

比赛简介

本次比赛是上海 AI Lab 举办的书生大模型实战营(第六期)的社区活动,在算力平台 d.run 上使用沐曦算力,通过微调 VLM 等方法识别输入的公式图片,输出对应的 LaTex 文本。限定使用 InternVL3.5-1B 模型。

之前也有过类似比赛,实战营(第五期)举办了论文分类打榜赛。微调 LLM 对论文摘要进行学科分类。

任务

具体来说,输入图片均为 texlive 渲染得到的 LaTex 公式图片,输出应为对应的 LaTex 公式文本。
例如对于输入:
image
期望的输出:

1
\sum_{i=1}^{\infty} \frac{1}{i^2} = \frac{\pi^2}{6} \quad \text{and} \quad \left\| \mathbf{A} \right\| = \sqrt{\lambda_{\max}(\mathbf{A}^T\mathbf{A})} \quad \text{where} \quad \mathbf{A} = \begin{bmatrix} \int_{0}^{1} x^2 dx & \frac{1}{2} \\ 2 & \int_{0}^{2} e^{-x} dx \end{bmatrix}

评估方式

哈希比较成功率: 模型生成的图片与参考图片哈希值完全相同的样本比例。
相似度比较成功率: 模型生成的图片与参考图片的图像相似度(直方图相似度/SSIM/MSE/特征点相似度 加权)高于阈值的样本比例。
最终综合得分: 上述两项成功率的加权平均值,全面反映模型的性能。

提示:测试数据的图像可能经过增强

点击展开详情

本次比赛可能会涉及一些基础图像增强变换,包括但不限于:

  1. 几何与空间变换
  • 旋转 (Rotation)
  • 透视变换 (Perspective)
  • 仿射变换 (Affine: 含缩放与错切)
  • 镜头畸变 (Lens Distortion)
  • 画布扩展与边缘裁剪 (Canvas Expansion & Border Trimming)
  1. 颜色与光照调整
  • 颜色抖动 (Color Jitter: 含亮度与对比度)
  • RGB 通道偏移 (RGB Shift)
  • 色温调整 (Color Temperature)
  • Gamma 校正 (Gamma Correction)
  • 通道随机丢失 (Channel Dropout)
  1. 噪声干扰
  • 高斯噪声 (Gaussian Noise)
  • 椒盐噪声 (Salt & Pepper Noise)
  • 泊松噪声 (Poisson Noise)
  • 散斑噪声 (Speckle Noise)
  1. 模糊与画质损伤
  • 高斯模糊 (Gaussian Blur)
  • 运动模糊 (Motion Blur)
  • JPEG 压缩伪影 (JPEG Artifacts)

评估数据包括两版,A榜 和 B榜,类似于验证集和测试集,样本均未公开。

资源

  • 代码:提供了两份 Lora SFT 的 baseline,基于 ms-swiftdemo 和基于 XTunerdemo,均包括评估框架和微调代码。
    实测第一个 demo 在A榜得分 56.50,B榜得分 48.60.(其实这个分数一下就能进前 60 了。。)

  • 数据:提供有一个包含 3000 个图像文本对的训练数据集,上面两个 demo 都使用该数据集。

  • 奖金:1st, 2nd, 3rd, 4-20th, 21-60th 分别有 6/4/3/2/1k RMB. 是真的我作证,因为上一期拿到了 1k.

分析

InternVL3.5-1B 是 LLaVA 式的 VLM, 这类的 VLM 约等于 Vison Backbone + MLP + LLM,其中 MLP 相当于视觉和语言的桥梁,把视觉 token 映射到视觉语义共享的特征空间,作为 LLM 的视觉输入。
跑完 baseline 后,简单来说有以下改进方向:

  • 数据:①增加训练数据(寻找/合成);②图像数据增强
  • 训练:③尝试不同的微调方式,如全量微调;调整微调不同部分的权重
  • 寻求更大的增益:④因为有明确的奖励信号,适合用 RL 优化模型偏好,把输出对齐到合法的 LaTex 文本;⑤通过后处理约束输出 LaTex 文本的合法性

过程

样本数量扩充
首先尝试了更大的开源数据集,直接扑街,A榜 1 分,B榜 0.2 分,统计命令出现频率发现数据分布天差地别,估计初始训练数据和 AB 榜相关性比较强。(后来举办方明确说明了初始训练数据和 AB 榜生成模式相同。好吧如果是我,我肯定给 B 榜数据构造点儿偏移)看了下初始数据,全是包含微分和矩阵的合成公式,本质上一个 LaTex 的高频子集,语法正确但不保证内容有意义。那么就根据初始数据的分布扩充数据。

第一版数据生成器比较粗糙,数据扩充到 9k 分数有提升。扩充到 20k 反而 A 榜分数有下降。这时由于评测时间过长,B 榜分数出现了严重的滞后,根据 A 榜表现放弃了 20k 版本的数据。
第二版数据生成器严格按照命令分布和长度生成数据,90k 数据效果依然拉跨,甚至不如 20k 版本。放弃。重心转移到下面几个方向。
这时还有两个可能没有探索:数据量上来后,是否该切到用全量微调;只约束了样本文本长度均值,没有约束分布。

数据清洗
观察文本命令分布时发现有很多不规则空格及换行符,感觉会干扰预测空间,与 AI 讨论后对文本做标准化,删除了所有换行符和多余空格。分数有提升。

数据增强
按照提示实现了类似 RandAug 的图像增强,调整某些增强的强度例如旋转角度在合理范围内。

多打几个补丁
Lora 仿佛在学习权重(而非特征)的残差,而且还能合并回原参数,像在预训练参数上打升级补丁。
那自然可以多打几个补丁强化任务适配,LLM,MLP, Vison, LLM+MLP 尝试了依次在这四个配置上微调,效果均有提升。
尝试了全量微调 LLM + MLP,A 榜效果等同 Lora 微调甚至略低,有点反直觉,可能是同时调整的参数太多。但这也符合(权重的)“残差”更易于学习的刻板印象。

小乌龙
提交允许自定义 prompt。什么?微调的 1B VLM 模型还有 prompt 工程价值?好吧确实可能有,我至少应该提示模型输出的公式语法要正确。于是设计了如下 prompt:

1
2
请根据图片中的公式生成对应的 latex 语法正确的公式文本。
*原 prompt:请根据图片中的公式生成对应的 latex 公式文本。

这是 native speaker 写出的中文句子?LLM 看了可能都不习惯,以后还是要多写作、多表达
结果,提交了八九次才发现第二版开始训练数据用的还是默认模板,搞笑。这个低级错误害我又提交了好几次重复权重来测 prompt 的影响,结论是好像没太大影响,B 榜波动小,A 榜波动大。

更进一步

一直瞎训提交 A 榜抽奖也没啥意思(主要是没啥进步空间了),不如趁这机会试试其他玩意

RL
对 RL 一直敬而远之,感觉需要大量时间和空间,事实证明确实如此。基于 trl 搭好 GRPO 框架,在 LLM 的辅助下设计了三个奖励函数,分别评估识别结果的语法正确性,和标签的编辑距离以及渲染结果和输入图片相似度(直接借用评测框架里的计算方式)。

小小的吐槽一下,怀疑比赛举办方没有 review 过 vibe coding 的评测框架代码,每种相似度评分都重复地进行预处理,以及实现了批量推理却没有使用,不知道是不是 B 榜出分缓慢的原因之一。

跑起来 GRPO 了,此时距比赛结束还有不到 48 小时,一看预计运行时间 40+ 小时,实际上更慢。中间还断了两次,感觉沐曦卡还不是特别稳定,国产算力任重道远。vllm 在沐曦上没找到合适环境,torch 版本太低,swift sample 也没调通,就龟速地一边推一遍训了。

log 里 clip_ratio=0、entropy 指标也在下降,Claude 说模型没学到啥,毕竟我看示例都是训了几百 ep,我这 1 ep 都跑不完,训了 600 steps,Lora 权重也没啥变化,有图为证,
image
RL 失败,按下不表。此时距比赛结束还有不到 30 小时.

约束解码/后处理
其实看到比赛题目时,有个疑问就浮在脑海:相似度评估真的有效吗
公式渲染方式一致,预处理方式一致,白/透明底黑字之间会有非常显著的差别吗?
对 10 张测试图像两两组合计算相似度均值,在默认加权中,相似度基本略高于 0.5-0.6,如果 B 榜评测时设置的相似度阈值较低,或者渲染成功就有分,那给 LLM 配个能正确渲染的保底输出,就能稳定提分。这里就要提到 transformers 的 AutoClass 类的 trust_remote_code 了,简单来说它允许加载模型仓库里的代码用于推理,给评估时的后处理提供了注入条件。实测默认返回一张初始训练集的样本 sample19996.png 结果是 A 榜 0 分,此路不通,推测相似度阈值应该在 0.7 甚至 0.75 以上。

按照约束解码的思路,根据完整的 LaTex 文法约束 token 采样,但 LaTex 语法略显复杂,先从错误样本做个简单版本:
推理 100 张测试图像有 6 个渲染错误,检查错误样本,归类错误原因。通过 logits processor 修复已知错误。于是请求 Claude,

在生成的公式中观察到的错误有:
1、没任何矩阵环境,生成了 &
2、生成不匹配的括号,例如 (]
3、编造的命令,例如:\vbeta
4、数字后出现下划线如 b_6_4
5、命令层级不匹配,如\begin{array} \begin{pmatrix} 后首先出现 \end{array}
6、…
请以类似约束解码的方式,举一反三地抑制不合法的输出,每种错误由一个函数来检查,使得易于检查和使用

Claude 创建了一个类 LatexState 记录生成过程的各种状态

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
@dataclass
class LatexState:
"""追踪当前已生成序列的语法状态"""
# 括号栈:追踪未闭合的括号
bracket_stack: list = field(default_factory=list)
# 环境栈:追踪 \begin{} \end{} 嵌套
env_stack: list = field(default_factory=list)
# 当前是否在矩阵环境中
in_matrix_env: bool = False
# 上一个有效 token 的文本
last_token_text: str = ""
# 已生成的完整文本
generated_text: str = ""

# 矩阵环境集合(允许使用 & 的环境)
MATRIX_ENVS = {
"matrix", "pmatrix", "bmatrix", "vmatrix", "Vmatrix",
"smallmatrix", "array", "aligned", "align", "cases",
"gather", "eqnarray",
}
# ...

基于 LatexState 的状态,定义了对应多种错误的检查函数,例如:

1
2
3
4
5
6
7
8
def check_ampersand_outside_matrix(candidate: str, state: LatexState) -> bool:
"""
错误类型1:没有矩阵环境时生成 &
返回 True 表示该 token 非法
"""
if '&' in candidate and not state.in_matrix_env:
return True
return False
点击展开其他检查函数
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def check_mismatched_brackets(candidate: str, state: LatexState) -> bool:
"""
错误类型2:生成不匹配的括号,如 (]
模拟将 candidate 加入当前括号栈,检查是否产生不匹配
"""
PAIRS = {')': '(', ']': '[', '}': '{'}
OPEN = set('([{')
CLOSE = set(')]}')

# 模拟当前栈
simulated_stack = list(state.bracket_stack)
for ch in candidate:
if ch in OPEN:
simulated_stack.append(ch)
elif ch in CLOSE:
if not simulated_stack:
return True # 无对应开括号
if simulated_stack[-1] != PAIRS[ch]:
return True # 括号类型不匹配
simulated_stack.pop()
return False

def check_invalid_command(candidate: str, state: LatexState) -> bool:
"""
错误类型3:编造的命令,如 \vbeta, \invalidcmd
检测 candidate 中是否包含不在白名单的 LaTeX 命令
"""
commands = re.findall(r'\\([a-zA-Z]+)', candidate)
for cmd in commands:
if cmd not in VALID_COMMANDS:
return True
return False


def check_consecutive_subscripts(candidate: str, state: LatexState) -> bool:
"""
错误类型4:数字或字母后出现连续下划线,如 b_6_4
检测合并后文本中是否有 x_y_z 的非法连续下标(未用{}包裹)
"""
combined = state.generated_text + candidate
# x_y_z 形式(未用 {} 包裹时的连续下标)
if re.search(r'(?<!\{)[^_{}]+_[^_{}]+_', combined):
return True
return False


def check_env_order_mismatch(candidate: str, state: LatexState) -> bool:
"""
错误类型5:\end{env} 与栈顶 \begin{env} 不匹配
如 \begin{array}\begin{pmatrix} 后出现 \end{array}
"""
ends = re.findall(r'\\end\{(\w+)\}', candidate)
simulated_stack = list(state.env_stack)
for env in ends:
if not simulated_stack:
return True # 没有对应的 \begin
if simulated_stack[-1] != env:
return True # 与栈顶不匹配
simulated_stack.pop()
return False


def check_illegal_characters(candidate: str, state: LatexState) -> bool:
"""
错误类型6:非法 Unicode 字符(中文、全角等)
"""
ILLEGAL_RANGES = [
(0x4E00, 0x9FFF), # 中文基本汉字
(0x3400, 0x4DBF), # 中文扩展A
(0xFF00, 0xFFEF), # 全角字符
(0x3000, 0x303F), # 中文标点
(0x0080, 0x009F), # C1控制字符
]
for char in candidate:
cp = ord(char)
for start, end in ILLEGAL_RANGES:
if start <= cp <= end:
return True
return False


def check_double_superscript(candidate: str, state: LatexState) -> bool:
"""
错误类型7:连续上标,如 x^2^3(未用{}包裹)
"""
combined = state.generated_text + candidate
if re.search(r'\^[^{}\^]+\^', combined):
return True
return False


def check_empty_group(candidate: str, state: LatexState) -> bool:
"""
错误类型8:空的花括号组 {}(通常无意义,在关键位置会导致渲染错误)
"""
combined = state.generated_text + candidate
# \frac{}{} 或 \sqrt{} 等关键命令后的空组
if re.search(r'\\(frac|sqrt|hat|bar|vec|dot)\{\}', combined):
return True
return False

斯巴拉西,然后 Claude 给出了 logits processor 的实现

点击展开详情
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class LatexConstraintProcessor(LogitsProcessor):
def __init__(
self,
tokenizer,
enabled_checkers: Optional[set] = None, # None 表示启用全部
penalty: float = -float('inf'), # 非法 token 的惩罚值
verbose: bool = False,
):
self.tokenizer = tokenizer
self.penalty = penalty
self.verbose = verbose
self.state = LatexState()

# 选择启用的检查器
if enabled_checkers is None:
self.checkers = CHECKERS
else:
self.checkers = {k: v for k, v in CHECKERS.items() if k in enabled_checkers}

# 预计算词表中每个 token 的解码文本(避免每步重复 decode)
self._vocab_decoded = self._precompute_vocab(tokenizer)
print(f"[LatexConstraintProcessor] 启用检查器: {list(self.checkers.keys())}")

def _precompute_vocab(self, tokenizer) -> dict:
vocab = tokenizer.get_vocab()
decoded = {}
for token, idx in vocab.items():
try:
decoded[idx] = tokenizer.convert_tokens_to_string([token])
except Exception:
decoded[idx] = ""
return decoded

def __call__(
self,
input_ids: torch.LongTensor,
scores: torch.FloatTensor,
) -> torch.FloatTensor:

# 更新状态:解码最新生成的 token
if input_ids.shape[1] > 0:
last_token_id = input_ids[0, -1].item()
last_text = self._vocab_decoded.get(last_token_id, "")
self.state.update(last_text)

# 对每个候选 token 逐一检查
banned_ids = []
for token_id, candidate_text in self._vocab_decoded.items():
for checker_name, checker_fn in self.checkers.items():
if checker_fn(candidate_text, self.state):
banned_ids.append(token_id)
if self.verbose:
print(f"[banned] id={token_id} text={candidate_text!r} reason={checker_name}")
break # 一个检查失败即禁止,无需继续

if banned_ids:
scores[0, banned_ids] = self.penalty

return scores

def reset(self):
"""每次新的生成前重置状态"""
self.state = LatexState()

遍历整个词表可还行,试一下果然慢得要死。而且模型开始无休止地吐 token ,猜测是 EOS token 也被 mask 了,
把 special tokens 从 mask 名单去除,推理仍未正常结束。此时比赛已经结束了,提交了返回固定样本的版本。

再检查发现 EOS token 后出现的都是 pad token,应该是批量推理的其他样本尚未结束生成。进而发现上面版本的 batchsize 错误,input_ids[0, -1] 和 scores[0, banned_ids] 仅处理了批次中的第一个样本,继续提示 Claude,

整理上面的代码,给出 LatexState 和 LogitsProcessor 的完整版本。
1、根据 LatexState 记录的生成状态动态计算需要 mask 的 token,例如括号、命令,_ 等;
2、计算包含中文字符等非 LaTex 标准字符的 token 进行静态 mask;
3、适配批量处理的场景;

几轮提示和修改后,得到了一个勉强可用版本,不足之处在于把包含环境的复杂序列也在 token 级别处理,实际上应该从 token 流的角度考虑,但这样引入了太多复杂性,到此为止。

点击展开详情
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
class LatexConstraintProcessor(LogitsProcessor):

def __init__(
self,
tokenizer=None,
model=None,
penalty: float = -float('inf'),
verbose: bool = False,
):
# 支持从 model.name_or_path 自动加载 tokenizer
if tokenizer is None:
assert model is not None, (
"请传入 tokenizer 或 model:\n"
" LatexConstraintProcessor(tokenizer=tokenizer)\n"
" LatexConstraintProcessor(model=model)"
)
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
model.name_or_path, trust_remote_code=True
)

self.tokenizer = tokenizer
self.penalty = penalty
self.verbose = verbose

# EOS / 特殊 token
self._eos_token_id = tokenizer.eos_token_id
special_ids = set(tokenizer.all_special_ids)
for attr in ('eos_token_id', 'bos_token_id', 'pad_token_id', 'unk_token_id'):
val = getattr(tokenizer, attr, None)
if val is not None:
special_ids.add(val)
self._special_token_ids = list(special_ids)

# 预计算词表
self._vocab_decoded, self._static_banned = self._precompute_vocab(tokenizer, special_ids)

# 预计算动态检查索引
self._precompute_token_indices()

# batch 状态(延迟初始化)
self.states: list[LatexState] = []

print(
f"[LatexConstraintProcessor] "
f"特殊token: {len(special_ids)}个 "
f"静态过滤: {len(self._static_banned)}个 "
f"动态检查词表: {len(self._vocab_decoded)}个"
)

# ── 词表预计算 ──────────────────────────────────────────────

def _is_legal_latex_char(self, text: str) -> bool:
"""判断文本是否只含合法 LaTeX 字符"""
LEGAL_PATTERN = re.compile(
r'^['
r'a-zA-Z0-9' # 英文字母和数字
r'\s' # 空白字符(空格、换行、tab)
r'\\{}()\[\]' # LaTeX 核心符号
r'\+\-\*/=<>!&|^~' # 运算符
r'_\^' # 上下标
r'.,;:\'"`' # 标点
r'#%' # 其他常用 r'@#%',按数据集改为 r'#%'
r']+$'
)
ILLEGAL_RANGES = [
(0x4E00, 0x9FFF), # 中文基本汉字
(0x3400, 0x4DBF), # 中文扩展A
(0x20000, 0x2A6DF), # 中文扩展B
(0xFF00, 0xFFEF), # 全角字符
(0x3000, 0x303F), # 中文标点
(0x0080, 0x009F), # 不可见控制符
(0xD800, 0xDFFF), # UTF-16 代理区
(0xE000, 0xF8FF), # 私有区
]
for char in text:
cp = ord(char)
for start, end in ILLEGAL_RANGES:
if start <= cp <= end:
return False
return bool(LEGAL_PATTERN.match(text))

def _precompute_vocab(
self,
tokenizer,
special_ids: set,
) -> tuple[dict[int, str], torch.Tensor]:
"""
返回:
vocab_decoded: 合法 token 的 id → 文本(用于动态检查)
static_banned: 非法字符 token 的 id tensor(静态屏蔽)
"""
vocab = tokenizer.get_vocab()
static_banned = []
vocab_decoded = {}

for token, idx in vocab.items():
# 特殊 token 直接保留,内容置空(不参与任何规则检查)
if idx in special_ids:
vocab_decoded[idx] = ""
continue
try:
text = tokenizer.convert_tokens_to_string([token])
except Exception:
vocab_decoded[idx] = ""
continue

if not self._is_legal_latex_char(text):
static_banned.append(idx)
else:
vocab_decoded[idx] = text

return vocab_decoded, torch.tensor(static_banned, dtype=torch.long)

def _precompute_token_indices(self):
"""
按字符内容预分组,供动态屏蔽 O(1) 查找:
_tokens_contain[ch] → 以字符 ch 开头的所有 token id 列表
_end_env_ids[env] → 包含 \\end{env} 的 token id 列表
_cmd_token_ids → 包含未知命令的 token id 列表(静态)
"""
# 关键字符分组
KEY_CHARS = ('^', '_', '&', ']', ')', '}')
self._tokens_startwith: dict[str, list[int]] = {ch: [] for ch in KEY_CHARS}

# \end{env} 分组
self._end_env_ids: dict[str, list[int]] = {}

# 非法命令(静态,加入 static_banned)
invalid_cmd_ids = []

for token_id, text in self._vocab_decoded.items():
if not text:
continue

# 按包含字符分组
for ch in KEY_CHARS:
if text.lstrip().startswith(ch):
self._tokens_startwith[ch].append(token_id)

# \end{env} 分组
m = re.search(r'\\end\{(\w+)\}', text)
if m:
env = m.group(1)
self._end_env_ids.setdefault(env, []).append(token_id)

# 非法命令检查
commands = re.findall(r'\\([a-zA-Z]+)', text)
if any(cmd not in VALID_COMMANDS for cmd in commands):
invalid_cmd_ids.append(token_id)

# 非法命令并入静态禁止
if invalid_cmd_ids:
extra = torch.tensor(invalid_cmd_ids, dtype=torch.long)
self._static_banned = torch.cat([self._static_banned, extra]).unique()

# ── 动态屏蔽 ────────────────────────────────────────────────

def _get_context_banned(self, state: LatexState) -> list[int]:
"""根据当前状态返回需要动态屏蔽的 token id"""
banned = []

# 规则1:^ 或 _ 之后尚未完整参数,禁止再出现 ^ 和 _
if not state.script_has_arg:
banned.extend(self._tokens_startwith['^'])
banned.extend(self._tokens_startwith['_'])

# 规则2:刚完成一个脚本参数(如 a^1 或 a^{12} 末尾),
# 则禁止紧接着生成同类脚本(防止 a^1^ 或 a_1_2)
# 注意:不禁止另一类,a^1_2 是合法的
elif state.script_just_completed and state.last_script_char:
banned.extend(self._tokens_startwith[state.last_script_char])

# 规则3:不在矩阵环境,禁止 &
if not state.in_matrix_env:
banned.extend(self._tokens_startwith['&'])

# 规则4:括号栈顶不匹配,禁止对应非法闭合括号
if state.bracket_stack:
top = state.bracket_stack[-1]
for illegal_close in BRACKET_MISMATCH.get(top, ()):
banned.extend(self._tokens_startwith[illegal_close])
else:
# 禁止所有闭合括号
for illegal_close in CLOSE_BRACKETS:
banned.extend(self._tokens_startwith[illegal_close])

# 规则5:\end{env} 必须与栈顶匹配
if state.env_stack:
top_env = state.env_stack[-1]
for env, ids in self._end_env_ids.items():
if env != top_env:
banned.extend(ids)

return banned

# ── __call__ ────────────────────────────────────────────────

def __call__(
self,
input_ids: torch.LongTensor, # [batch, seq_len]
scores: torch.FloatTensor, # [batch, vocab_size]
) -> torch.FloatTensor:

batch_size = input_ids.shape[0]
device = scores.device

# 延迟初始化 / batch size 变化时重置状态
if len(self.states) != batch_size:
self.states = [LatexState() for _ in range(batch_size)]

# 设备对齐(只在首次或设备变化时迁移)
if self._static_banned.device != device:
self._static_banned = self._static_banned.to(device)

for i in range(batch_size):
# 已生成 EOS 的样本跳过
if self._eos_token_id is not None:
if (input_ids[i] == self._eos_token_id).any():
continue

# 更新第 i 个样本的状态
if input_ids.shape[1] > 0:
last_id = input_ids[i, -1].item()
last_text = self._vocab_decoded.get(last_id, "")
self.states[i].update(last_text)

if self.verbose:
print(
f"[batch={i}] last_token={last_text!r} "
f"env_stack={self.states[i].env_stack} "
f"bracket_stack={self.states[i].bracket_stack} "
f"in_matrix={self.states[i].in_matrix_env}"
)

# 静态屏蔽
scores[i, self._static_banned] = self.penalty

# 动态屏蔽
context_banned = self._get_context_banned(self.states[i])
if context_banned:
scores[i, context_banned] = self.penalty

# 强制恢复所有特殊 token(确保 EOS 不被屏蔽)
for sp_id in self._special_token_ids:
scores[:, sp_id] = scores[:, sp_id].clamp(min=0)

return scores

def reset(self):
"""每次新的 generate 调用前重置所有状态"""
self.states = []

效果
实现的简易约束解码效果挺抽象,公式生成变得合法了,但正确性下降了。
随机 100 个合成公式的测试结果:评分大幅下降,渲染失败个数从 3 个减少到 1 个,换了个子集甚至得到更多的渲染失败。
问题可能在于单个 token 级别的约束实际破坏了模型的 LaTex 表示能力,表示能力是建立在语法单元上的,
而它们由多个 token 组成(或者相反一个 token 超过了一个基本单元),因此单个 token 的约束和语法单元并不对齐

image

插曲

跑评测时发现某些情况会报错,图片成功渲染了相似度却通通是 0,细察发现图片读取部分有问题,
用 dvipng 把 latex 编译结果转换为 png 格式时,可能会把 RGBA 优化成 colormap,导致 cv2.imread 读取有误,应该也影响 hash 。。
复现 bug 提了 issue 和 PR,在微信群里看到好像有 B 榜分数有震荡,不知道是不是这个原因
同一张图片,因读取问题产生的相似度分数差异:

image image

经验、教训与感受

想了下,比赛过程有些缺乏章法,接近炼丹了。

  • 应该尽早构造验证集和本地评测流程,检查错误样本。
    这样就能更早地推进约束解码/后处理的方案,以及通过本地评估指导数据合成。
  • 科学炼丹:应该排实验计划,对每个有希望的方向进行试训。
    比如对全量微调探索太少,可以尝试分阶段微调和清洗后的合成数据集微调。
  • 应该尽早投入比赛。。好吧其实没邀请对象没有那么多算力。
  • 拥抱 AI 力度不够(还是舍不得买 Claude,以及完全放手给 Coding Agent)
    上一期最后没抢到算力,准备的最后一版数据没用上,直接被挤出前 20 了。最后观摩一个公开出来的前三的方案,AI 代码 AI 报告看得我两眼一黑。毫无疑问这次我又将与 Claude Opus 作战,本着打不过就加入的原则在 review 和验证中大量使用 AI 代码,但效率有限。这次我将把奖金上交智谱。(好吧有 Codex 已经满足了)
  • 比第一次有进步
    第一次完全没理解模型仓库 GitHub 格式和 HF 格式的差异,这次终于有所了解,对 LLM/VLM 处理输入到输出的过程也更清晰了。尝试 ms-swift 框架的同时第一次跑起来 RL 代码(虽然是用 trl),强化学习并非高高在上,遥不可及!
  • 把 LLM 的知识蒸馏给自己
    接受 AI,拥抱 AI,努力达到 AI 的边界,但不把 AI 的能力误认为自己的能力。
    可是我们做到一件事,必须要掌握全部细节吗,我此刻产生了怀疑。领导者的能力一定要覆盖或超过下属吗?知其然知其所以然,可学也无涯。只能寄希望于“本体越强,替身使者越强”,否则只需要强化人机同步率了,健身还有什么用呢?也许我还停留在“学习技术”提供的多巴胺舒适区吧。
  • 对于可靠系统 vibe coding requires review.

仓库

除了最初版的数据生成器丢失了,本次比赛产生的各种未经整理的代码放在了 kafmws/VLM-formula-recognition-dataset.git

未尽探索

增加视觉 token 数量、增大图像输入分辨率,也是显而易见的方向。但担心会破坏预训练模型的视觉特征提取能力,没尝试这些脑海里一闪而过的念头,还是有些可惜。RL 和约束解码也尝试太晚,没有成功应用到提交中。
下次再来!

2026/03/02