If the broadcasting is what's bothering you, you could use an nn.Flatten
layer to do it:
>>> m = nn.Sequential(
... nn.Flatten(),
... nn.Linear(24*768, 768))
>>> x = torch.rand(1, 24, 768)
>>> m(x).shape
torch.Size([1, 768])
If you really want the extra dimension, you can unsqueeze the tensor on dim=1:
>>> m(x).unsqueeze(1).shape
torch.Size([1, 1, 768])
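For contrast, here is a minimal sketch of the broadcasting behavior mentioned above: applied directly to a 3D input, nn.Linear only transforms the last dimension, so each of the 24 rows is mapped independently and the middle dimension is kept.

```python
import torch
import torch.nn as nn

x = torch.rand(1, 24, 768)

# Without Flatten: nn.Linear acts on the last dim only,
# so the 24 rows are transformed independently (broadcasting).
lin = nn.Linear(768, 768)
print(lin(x).shape)  # torch.Size([1, 24, 768])

# With Flatten: all 24*768 features are mixed into one 768-dim output.
m = nn.Sequential(nn.Flatten(), nn.Linear(24 * 768, 768))
print(m(x).shape)  # torch.Size([1, 768])
```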