forked from torch/nn
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLayerNormalization.lua
More file actions
27 lines (24 loc) · 895 Bytes
/
LayerNormalization.lua
File metadata and controls
27 lines (24 loc) · 895 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
-- Reference: https://arxiv.org/pdf/1607.06450.pdf (Section 3)
local LayerNormalization, parent = torch.class('nn.LayerNormalization', 'nn.Sequential')

--- Layer Normalization: normalizes each sample to zero mean and unit
-- standard deviation over its feature dimension, optionally followed by a
-- learned per-feature gain and bias.
-- NOTE(review): the Mean/Replicate settings assume a (batch, nOutput) input
-- layout (nInputDims = 1) — confirm against callers.
-- @param nOutput number of features per sample
-- @param bias    initial value for the learned bias (default 0); used only when affine
-- @param eps     numerical-stability constant forwarded to nn.Normalize (default 1e-10)
-- @param affine  when true (default), append learnable gain (CMul) and bias (Add) layers
function LayerNormalization:__init(nOutput, bias, eps, affine)
   parent.__init(self)
   eps = eps or 1e-10
   -- `affine or true` would discard an explicit false; this keeps it.
   affine = (affine == nil) and true or affine
   bias = bias or 0

   -- Mean subtraction: fork the input into (identity, per-sample mean
   -- replicated back to nOutput entries) and take the elementwise difference.
   self:add(nn.ConcatTable()
      :add(nn.Identity())
      :add(nn.Sequential()
         :add(nn.Mean(1, 1))
         :add(nn.Replicate(nOutput, 1, 1))))
   self:add(nn.CSubTable())
   -- Divide by the L2 norm, then rescale by sqrt(nOutput): together this
   -- divides the centered sample by its (biased) standard deviation.
   self:add(nn.Normalize(2, eps))
   -- Fix: math.sqrt, not torch.sqrt — nOutput is a plain Lua number and
   -- torch.sqrt is a Tensor operation; math.sqrt is the scalar square root.
   self:add(nn.MulConstant(math.sqrt(nOutput)))

   if affine then
      -- nn.Add(nOutput, false): per-feature bias vector, initialized to `bias`.
      local biasTransform = nn.Add(nOutput, false)
      biasTransform.bias:fill(bias)
      -- nn.CMul(nOutput): per-feature gain, initialized to identity (1).
      local gainTransform = nn.CMul(nOutput)
      gainTransform.weight:fill(1.)
      self:add(gainTransform)
      self:add(biasTransform)
   end
end