๐ก
imageNet์ dataset์์ ํฐ ๋ฏธ๋๋ฐฐ์น๋ ์ต์ ํ๋ฅผ ์ํค๋๋ฐ ์ด๋ ต๋ค๋๊ฒ์ ๋ณด์ฌ์ค(ํ์ง๋ง ์ด ๋ฌธ์ ๊ฐ ํด๊ฒฐ๋๋ฉด ํ๋ จ๋ชจ๋ธ์ ์ข์ ์ผ๋ฐํ๋ฅผ ๊ฐ์ง). ์ด ๋ฌธ์ ๋ฅผ ์ํด ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ์ ํจ์๋ก ํ์ต์๋๋ฅผ ์กฐ์ ํ๊ธฐ ์ํ ์ด๋งค๊ฐ๋ณ์ ์๋ ์ ํ ํ์ฅ ๊ท์น์ ์ฑํํจ. ๋ํ ํ๋ จ ์ด๊ธฐ์ ์ต์ ํ ๋ฌธ์ ๋ฅผ ๊ทน๋ณตํ๋ ์๋ก์ด ์๋ฐ์
๋ฐฉ์ ๊ฐ๋ฐ.
This paperโs goal
- ๋ถ์ฐ์ ๋น๋๊ธฐ์ SGD(stochastic gradient descent)๊ฐ ๋๊ท๋ชจ ํธ๋ ์ด๋์ ์ ํฉํ๋ค๋ ๊ฒ์ ์ฆ๋ช ํ๊ณ ์ค์ฉ์ ์ธ ๊ฐ์ด๋ ์ ๋ฌํ๊ธฐ
ํฐ ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ ๋ค๋ฃจ๊ธฐ ์ํด ํ์ํ ๊ฒ
- hyper-parameter-free linear scaling rule๋ฅผ ์ด์ฉํ์ฌ learning rate ์กฐ์ ํ๊ธฐ
- ์ด ๊ฐ์ด๋๋ผ์ธ์ earlier worker์์ ์ค๋ฆฝ๋จ
- ๊ฒฝํ์ ํ๊ณ๋ ์ ์ดํด๋์ง ์์ผ๋ฉฐ, ๋น๊ณต์์ ์ด๊ธฐ ๋๋ฌธ์ research community์ ์ ์๋ ค์ง์ง ์์
- ์ด ๋ฒ์น์ ์ฑ๊ณต์ ์ผ๋ก ์ ์ฉํ๊ธฐ ์ํด ์๋ก์ด ์ค๋น์ ๋ต ์ ์
- ์ค๋น์ ๋ต = ์ด๊ธฐ ์ต์ ํ๋ฌธ์ ๋ฅผ ๊ทน๋ณตํ๊ธฐ ์ํด ๋ฎ์ learning rate ์ฌ์ฉํ๊ธฐ
- ์์ค L(w)์
L(w)=1โฃXโฃโxโXl(x,w)L(w) = \frac{1}{|X|}\sum_{x\in X}l(x, w)
- w : ๋คํธ์ํฌ์ ๊ฐ์ค์น(weight of a network)
- X : ๋ ์ด๋ธ์ด ์ง์ ๋ training set
- l(x, w) : ์ํxโXx \in X๊ณผ ๋ผ๋ฒจ y์์ ๊ณ์ฐ๋ ์์ค
- l : ๋ถ๋ฅ์์ค(๊ต์ฐจ์ํธ๋กํผ)์ w์๋ํ ์ ๊ท์์ค์ ํฉ
- Minibatch stochastic gradient decent๋ ๋ฏธ๋๋ฐฐ์น์์ ์๋ํจ. ๋ณดํต ๊ฐ๋จํ๊ฒ SGD๋ผ๊ณ ํ๋ฉฐ, ๋ค์ ์
๋ฐ์ดํธ๋ฅผ ์ํํจ
wt+1=wtโฮท1nโxโBโฝ(x,wt)w_{t+1} = w_t - \eta\frac{1}{n}\sum_{x \in B} \bigtriangledown (x, w_t)
- BB : X์์ ์ํ๋ง๋ ๋ฏธ๋๋ฐฐ์น์
- n=โฃBโฃn= |B | : ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ
- ฮท\eta : learning rate
- t : iteration index
- ์ด ๋ ผ๋ฌธ์์๋ ์ค์ ๋ก momentum SGD๋ฅผ ์ฌ์ฉํจ.
2.1 Learning rates for Large Minibatches
- ๋ชฉํ
- ์์ ๋ฏธ๋๋ฐฐ์น ๋์ ํฐ ๋ฏธ๋๋ฐฐ์น๋ฅผ ์ฐ๋ฉด์ training๊ณผ ์ผ๋ฐํ ์ ํ๋ ์ ์งํ๊ธฐโ worker๋ณ ์์ ๋์ ์ค์ด๊ฑฐ๋ ๋ชจ๋ธ ์ ํ๋๋ฅผ ์ ํ์ํค์ง ์๊ณ ๊ฐ๋จํ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ๋ฌ worker๋ก ํ์ฅํ ์ ์๊ธฐ ๋๋ฌธ์ ๋ถ์ฐํ์ต์์ ํนํ ์ค์!
- (worker์ GPU๋ฅผ ๊ฐ์ ์๋ฏธ๋ก ์ฌ์ฉ)
- learning rate scaling rule์ด ๊ด๋ฒ์ํ ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ์ ๋๋ผ์ธ์ ๋๋ก ํจ๊ณผ์ ์linear scaling rule : ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ์ k๋ฅผ ๊ณฑํ๋ฉด, ํ์ต๋ฅ ์ k๋ฅผ ๊ณฑํ๊ธฐ
- ๋ค๋ฅธ hyper-parameter(weight decay)๋ ๋ฐ๊พธ์ง ์๊ธฐ
- linear scaling rule์ ์์ ๋ฏธ๋๋ฐฐ์น์ ํฐ ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ ์ ํ๋๋ฅผ ์ผ์น์ํค๋๋ฐ ๋์์ ์ค. ๋, training curve๋ ๋งค์น ์ํดโ ์คํ์ ์ ๋น ๋ฅด๊ฒ ๋น๊ตํ๊ณ ๋๋ฒ๊น ํ ์ ์๊ฒ๋จ
- interpretation
- leaner scaling rule๊ณผ ์ ํจ๊ณผ์ ์ธ์ง ์ค๋ช
ํ๊ฒ ์โ iteration์ด t, weight๊ฐ w์ธ ๋คํธ์ํฌ,0โคj<k0 \leq j <k์ ๋ํด k๊ฐ์ ๋ฏธ๋๋ฐฐ์นBjB_j์ ์ํ์ค๋ฅผ ๊ณ ๋ คํด์ผํจ
- k SGD iteration(์์ ๋ฏธ๋๋ฐฐ์นBjB_j, learning rate๊ฐฮท\eta)์ single iteration(์ฌ์ด์ฆ๊ฐ kn ์ธ ํฐ ๋ฏธ๋๋งค์นโชjBj\cup_j B_j์ ํ์ต๋ฅ ฮท^\hat{\eta})์ ์คํ ํจ๊ณผ๋ฅผ ๋น๊ต
- learning rate๊ฐฮท\eta, ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ๊ฐ n์ธ SGD์ k iteration
wt+k=wtโฮท1nโj<kโxโBjโฝl(x,wt+j)โโโโ(3)w_{t+k} = w_t -\eta \frac{1}{n} \sum_{j<k}\sum_{x \in B_j} \bigtriangledown l (x, w_{t+j})---- (3)
- learning rate๊ฐฮท^\hat{\eta}, ์ฌ์ด์ฆ๊ฐ kn ํฐ ๋ฏธ๋๋ฐฐ์นโชjBj\cup _j B_j๋ฅผ ์ฌ์ฉํ ๋จ์ผ ๋จ๊ณ๋ฅผ ์ฌ์ฉํ๋ฉด
w^t+1=wtโฮท^1knโj<kโxโBjโฝl(x,wt)โโโโ(4)\hat{w}_{t+1} = w_t - \hat{\eta}\frac{1}{kn}\sum_{j<k}\sum_{x \in B_j}\bigtriangledown l (x, w_t)---- (4) - learning rate๊ฐฮท\eta, ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ๊ฐ n์ธ SGD์ k iteration
- ์์ํ๊ฒ์ฒ๋ผ ์ ๋ฐ์ดํธ ๊ฐ์ ์๋ก ๋ค๋ฅด๋ฉฐ,w^t+1=wt+k\hat{w}_{t+1} = w_{t+k}์ผ ํ๋ฅ ์ ์์.
- ํ์ง๋ง, j<k์ ๋ํดโฝl(x,wt)โโฝl(x,wt+j)\bigtriangledown l (x, w_t) \approx \bigtriangledown l(x, w_{t+j})๋ ๊ฐ์ ํ ์ ์๋ ๊ฒฝ์ฐฮท^=kฮท\hat{\eta} = k\eta๋ก ์ค์ ํ ๋w^t+1โwt+k\hat{w}_{t+1} \approx w_{t+k}๊ฐ ์์ฑ๋๊ณ , ์์ ๋ฏธ๋๋ฐฐ์น SGD์ ํฐ ๋ฏธ๋๋ฐฐ์น SGD๋ ์๋ก ์ ์ฌํจ.
- ๋ํ, ๊ฐ๋ ฅํ ๊ฐ์ ์์๋ ๋ถ๊ตฌํ๊ณ ์ด๊ฒ์ด ์ฌ์ค์ด๋ผ๋ฉด,ฮท^=kฮท\hat{\eta} = k\eta๋ฅผ ์ค์ ํ ๊ฒฝ์ฐ์๋ง ๋ ์ ๋ฐ์ดํธ๊ฐ ์ ์ฌํ๋ค๊ณ ๊ฐ์กฐํจ.
- k SGD iteration(์์ ๋ฏธ๋๋ฐฐ์นBjB_j, learning rate๊ฐฮท\eta)์ single iteration(์ฌ์ด์ฆ๊ฐ kn ์ธ ํฐ ๋ฏธ๋๋งค์นโชjBj\cup_j B_j์ ํ์ต๋ฅ ฮท^\hat{\eta})์ ์คํ ํจ๊ณผ๋ฅผ ๋น๊ต
- ์ ํด์์ linear scaling rule์ด ์ ์ฉ๋๊ธฐ๋ฅผ ๋ฐ๋ผ๋ ํ๊ฐ์ง ๊ฒฝ์ฐ์ ๋ํ ์ง๊ด์ ์ ๊ณตํจ.ฮท^=kฮท\hat{\eta} = k \eta(๋ฐ ์ค๋น)์ธ ์คํ์์ ์์ ๋ฏธ๋๋ฐฐ์น SGD์ ํฐ ๋ฏธ๋๋ฐฐ์น SGD๋ ๋ชจ๋ธ์์์ ๊ฐ์ ๋ง์ง๋ง ์ ํ๋๋ฟ๋ง ์๋๋ผ training curve๋ํ ๊ฝค ๋งค์น๋จ. ์คํ๊ฒฐ๊ณผ๋ ์ ๊ทผ์ฌ์น๊ฐ ๋๊ท๋ชจ ์ค์ ๋ฐ์ดํฐ์์ ์ ํจํ ์ ์๋ค๊ณ ์ ์ํ๋ค.
- ๊ทธ๋ฌ๋ ์กฐ๊ฑดโฝl(s,wt)โโฝl(x,wt+j)\bigtriangledown l (s, w_t) \approx \bigtriangledown l (x, w_{t+j})๊ฐ ์ ์ง๋์ง ์๋ ๋๊ฐ์ง ๊ฒฝ์ฐ๊ฐ ์์
- ๋คํธ์ํฌ๊ฐ ๋น ๋ฅด๊ฒ ๋ฐ๋๋์ ์ด๊ธฐ ํ๋ จ(initial training)โ 2.2์ ์ค๋น๋จ๊ณ๋ฐฉ๋ฒ์ ์ฌ์ฉํด ํด๊ฒฐ
- ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ๊ฐ ๋ฌดํ์ ํ์ฅ๋ ์ ์์.
- ๊ฒฐ๊ณผ๋ ๋ค์ํ ๋ฒ์์ ํฌ๊ธฐ์์ ์์ ์ ์ด์ง๋ง, ํน์ ์ง์ ์ ๋์ผ๋ฉด ์ ํ๋๊ฐ ๋น ๋ฅด๊ฒ ๋จ์ด์ง
- ์ด ์ง์ ์ imageNet ์คํ์์ 8k๋ณด๋ค ํผ
- leaner scaling rule๊ณผ ์ ํจ๊ณผ์ ์ธ์ง ์ค๋ช
ํ๊ฒ ์โ iteration์ด t, weight๊ฐ w์ธ ๋คํธ์ํฌ,0โคj<k0 \leq j <k์ ๋ํด k๊ฐ์ ๋ฏธ๋๋ฐฐ์นBjB_j์ ์ํ์ค๋ฅผ ๊ณ ๋ คํด์ผํจ
- Discussion
- ์๊ธฐ linear scaling rule์ Krizhevsky์ ์ํด ์ฑํ๋จ. ํ์ง๋ง Krizhevsky๋ ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ๊ฐ 128์์ 1024๋ก ์ฆ๊ฐํ ๋ ์๋ฌ๊ฐ 1%๊ฐ ์ฆ๊ฐํ๋ค๊ณ ๋ณด๊ณ ํจ. ๋ฐ๋ฉด์ ์ฐ๋ฆฌ๋ ํจ์ฌ ๋ ๊ด๋ฒ์ํ ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ ๋ฒ์์์ ์ด๋ป๊ฒ ์ ํ๋๋ฅผ ์ ์งํ๋์ง ๋ณด์ฌ์ค.
- Chen์ ์๋ง์ ๋ถ์ฐ SGD ๋ณํ์ ๋น๊ต๋ฅผ ์ ์ํจ. ๊ทธ๋ค์ linear scaling rule์ ์ฌ์ฉํ์ง๋ง, ๋ฏธ๋๋ฐฐ์น ๊ธฐ์ค์ ์ ์ค์ ํ์ง ์์.
- Li๋ ์๋ ดํ ์ ํ๋ ์์ค ์์ด ๋ฏธ๋๋ฐฐ์น๊ฐ 5210๊น์ง์ธ ๋ถ์ฐ imageNet ํ๋ จ์ ๋ณด์ฌ์ค. ํ์ง๋ง ์ฐ๋ฆฌ๊ฐ ํต์ฌ์ ์ผ๋ก ๊ธฐ์ ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ๋ก ํ์ต๋ฅ ์กฐ์ ์ ์ํด hyper-parameter search rule์ ๋ณด์ฌ์ฃผ์ง ๋ชปํจ.
- ์ต๊ทผ ์ฐ๊ตฌ์์, Bottou ๋ฑ [4] (ยง4.2)์ ๋ฏธ๋๋ฐฐ์น์ ์ด๋ก ์ ์ฅ๋จ์ ์ ๊ฒํ ํ๊ณ ์ ํ ์ค์ผ์ผ๋ง ๊ท์น์ ๋ฐ๋ผ solver๊ฐ ๋ณธ ์์ ์์ ํจ์๋ก ๋์ผํ ํ๋ จ ๊ณก์ ์ ๋ฐ๋ฅด๋๊ฒ์ ๋ณด์ฌ์ค. ํ์ต๋ฅ ์ ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ์ ๋ ๋ฆฝ์ ์ธ ์ต๋ ์๋๋ฅผ ์ด๊ณผํด์๋ ์ ๋๋ฉฐ(๋ฐ๋ผ์ ์์ ์ด ์ ๋นํ๋จ), ์ ๋ก ์๋ ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ๋ก ์ด๋ฌํ ์ด๋ก ์ ์คํ์ ์ผ๋ก ๊ฒ์ฆํจ.
2.2 warmup
- ๋
ผ์ํ ๋๋ก, ๋๊ท๋ชจ ๋ฏธ๋๋ฐฐ์น(์: 8k)์ ๊ฒฝ์ฐ, ์ ๊ฒฝ๋ง์ด ๋น ๋ฅด๊ฒ ๋ณํ ๋ ์ ํ ์ค์ผ์ผ๋ง ๊ท์น์ด ๋ถ๊ดด๋จ. ์ด๋ ํ๋ จ ์ด๊ธฐ์ ํํ ๋ฐ์ํจโ ์ด ๋ฌธ์ ๋ฅผ ์ ์ ํ ์ค๊ณ๋ ์์
[16]์ ์ํด ์ํ์ํฌ ์ ์์.
- ํ๋ จ ์์ ์์ ๋ ๊ณต๊ฒฉ์ ์ธ ํ์ต๋ฅ ์ ์ฌ์ฉํ๋ ์ ๋ต
- Constant warmup
- [16]์์ ์ ์๋ ์์ ์ ๋ต์ ํ๋ จ์ ์ฒ์ ๋ช ์ํฌํฌ ๋์ ๋ฎ์ ๊ณ ์ ํ์ต๋ฅ ์ ์ฌ์ฉ
- ์ฌ์ ํ๋ จ๋ ๋ ์ด์ด๋ฅผ ์๋ก ์ด๊ธฐํ๋ ๋ ์ด์ด์ ํจ๊ป ์ธ๋ฐํ๊ฒ ์กฐ์ ํ๋ ๊ฐ์ฒด ๊ฒ์ถ ๋ฐ ์ธ๊ทธ๋ฉํ ์ด์ ๋ฐฉ๋ฒ์ Constant warmup์ด ์ ์ฉ
- ๋๊ท๋ชจ ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ kn์ผ๋ก ํ ImageNet ์คํ์์, ์ฒ์ 5 epochs๋์ ํ์ต๋ฅ ฮท๋ก ํ๋ จํจ. ๊ทธ ํ ๋ชฉํ ํ์ต๋ฅ ฮท^=kฮท\hat{\eta} = k\eta๋ก ๋์๊ฐ๋ ค๊ณ ํ์. ๊ทธ๋ฌ๋ ํฐ k๊ฐ ์ฃผ์ด์ง ๊ฒฝ์ฐ, ์ด Constant warmup๋ง์ผ๋ก๋ ์ต์ ํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ์ ์ถฉ๋ถํ์ง ์๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์์ผ๋ฉฐ, ๋ฎ์ ํ์ต๋ฅ ์์ ๋จ๊ณ์์์ ์ ํ์ ํ๋ จ ์ค๋ฅ๋ฅผ ๊ธ์ฆ์ํฌ ์ ์์โ ์ด๋ก ์ธํด ์ฐ๋ฆฌ๋ ๋ค์๊ณผ ๊ฐ์ ์ ์ง์ ์์ (gradual warmup)์ ์ ์
- gradual warmup
- ํ์ต๋ฅ ์ ์์ ๊ฐ์์๋ถํฐ ํฐ ๊ฐ์ผ๋ก ์ ์ง์ ์ผ๋ก ์ฆ๊ฐ์ํค๋ ๋์์ ์ธ warmup
- ํ์ต๋ฅ ์ ๊ฐ์์ค๋ฌ์ด ์ฆ๊ฐ๋ฅผ ํผํ๋ฉฐ, ํ๋ จ ์ด๊ธฐ์ ๊ฑด๊ฐํ ์๋ ด(healthy convergence)์ ํ์ฉํจ.
- ์ค์ ๋ก ๋๊ท๋ชจ ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ kn์ผ๋ก ์์ํ์ฌ ํ์ต๋ฅ ์ฮท\eta์์ ์์ํ์ฌ ๊ฐ ๋ฐ๋ณต์์ ์์๋๋งํผ ์ฆ๊ฐ์์ผ 5 epochs ํ์ฮท^=kฮท\hat{\eta} = k\eta์ ๋๋ฌํ๋๋ก ํจํฉ๋๋ค (์ ํํ warmup ๊ธฐ๊ฐ์ ๋ํ ๊ฒฐ๊ณผ๋ ๊ฒฌ๊ณ ํจ). warmup ํ์๋ ์๋์ ํ์ต๋ฅ ๋ก ๋์๊ฐ.
2.3 Batch Normalization with large minibatches
- Batch Normalization : ๋ฏธ๋๋ฐฐ์น ์ฐจ์์ ๋ฐ๋ผ ํต๊ณ๋์ ๊ณ์ฐ
- ๊ฐ ์ํ์ ์์ค์ ๋ ๋ฆฝ์ฑ์ ๊นจ๋จ๋ฆผ
- ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ์ ๋ณํ๋ ์ต์ ํ๋๋ ์์ค ํจ์์ ๊ธฐ๋ณธ ์ ์๋ฅผ ๋ณ๊ฒฝ
- ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ๋ณ๊ฒฝํ ๋ ์์ค ํจ์๋ฅผ ๋ณด์กดํ๊ธฐ ์ํ shortcut
- ํต์ ์ค๋ฒํค๋๋ฅผ ํผํ๊ธฐ ์ํ ์ค์ฉ์ ์ธ ๊ณ ๋ ค
- ์์ค ํจ์๋ฅผ ๋ณด์กดํ๋ ๋ฐ ํ์์
- BN์ด ์ํ๋๊ณ ํ์ฑํ๊ฐ ์ํ ๊ฐ์ ๊ณ์ฐ๋ ๋๋ ์ด๋ฌํ ๊ฐ์ ์ด ์ฑ๋ฆฝํ์ง ์์
- ๊ฐ์ : per-sample loss์ธl(x,w)l(x, w)์ด ๋ค๋ฅธ ๋ชจ๋ ์ํ๊ณผ ๋ ๋ฆฝ์ ์
- ํฌ๊ธฐ๊ฐ n์ธ ๋จ์ผ ๋ฏธ๋๋ฐฐ์น B์ ์์ค์L(B,w)=1nโxโBlB(x,w)L(B, w) = \frac{1}{n}\sum_{x \in B}l_B(x, w)๋ก ๋ํ
- BN์ด ์ ์ฉ๋ ๊ฒฝ์ฐ, ํ๋ จ ์ธํธ๋ ์๋ ํ๋ จ ์ธํธ X์์ ์ถ์ถ๋ ํฌ๊ธฐ๊ฐ n์ธ ๋ชจ๋ ๊ตฌ๋ณ๋๋ ๋ถ๋ถ์งํฉ์ ํฌํจ โXnX^n
- training loss L(w)
L(w)=1โฃXnโฃโBโXnL(B,w)L(w) = \frac{1}{|X^n|}\sum_{B \in X^n}L(B, w) \ \- ๋ง์ฝ B๋ฅผXnX^n์ '๋จ์ผ ์ํ(single sample)'๋ก ๋ณด๋ฉด, ๊ฐ ๋จ์ผ ์ํ B์ ์์ค์ด ๋ ๋ฆฝ์ ์ผ๋ก ๊ณ์ฐ๋จโ ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ n์ด BN ํต๊ณ๋ฅผ ๊ณ์ฐํ๋ ๋ฐ ์ค์ํ ๊ตฌ์ฑ ์์์์ ์ ์ํด์ผ ํจ
- ๊ฐ worker์ ๋ฏธ๋๋ฐฐ์น ์ํ ํฌ๊ธฐ n์ด ๋ณ๊ฒฝ๋๋ฉด ์ต์ ํ๋๋ ๊ธฐ๋ณธ ์์ค ํจ์ L์ด ๋ณ๊ฒฝ๋๊ฒ ๋จ
- BN์ด ๋ค๋ฅธ n์ผ๋ก ๊ณ์ฐํ ํ๊ท /๋ถ์ฐ ํต๊ณ๋ ์๋ก ๋ค๋ฅธ ์์ค์ ๋ฌด์์ ๋ณ๋์ ๋ํ๋
- ๋ถ์ฐ๋(๊ทธ๋ฆฌ๊ณ ๋ฉํฐ-GPU) ํ๋ จ
- ๋ง์ผ worker๋น ์ํ ํฌ๊ธฐ n์ด ๊ณ ์ ๋์ด ์๊ณ ์ด ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ๊ฐ kn์ผ ๋
- ๊ฐ ์ํBjB_j๊ฐXnX^n์์ ๋ ๋ฆฝ์ ์ผ๋ก ์ ํ๋ k๊ฐ์ ์ํ ๋ฏธ๋๋ฐฐ์น๋ก ๋ณผ ์ ์์.
- ์ด ๊ด์ ์์ BN ์ค์ ์์, k๊ฐ์ ๋ฏธ๋๋ฐฐ์นBjB_j๋ฅผ ๋ณธ ํ, (3)๊ณผ (4)๋ ๋ค์๊ณผ ๊ฐ์ด ๊ณ์ฐ๋จ
wt+k=wtโฮทโj<kโฝL(Bj,wt+j)w_{t+k} = w_t - \eta \sum_{j<k} \bigtriangledown L (B_j, w_{t+j})wt+k^=wtโฮท^1kโj<kโฝL(Bj,wt+j)\hat{w_{t+k}} = w_t - \hat{\eta} \frac{1}{k} \sum_{j<k} \bigtriangledown L (B_j, w_{t+j})
- ๋ณธ ์ฐ๊ตฌ์์๋ฮท^=kฮท\hat{\eta} = k\eta๋ก ์ค์ ํ๊ณ worker ์ k๋ฅผ ๋ณ๊ฒฝํ ๋ worker๋น ์ํ ํฌ๊ธฐ n์ ์ผ์ ํ๊ฒ ์ ์งํจ. ๊ทธ๋ฆฌ๊ณ ๋ค์ํ ๋ฐ์ดํฐ์ ๊ณผ ๋คํธ์ํฌ์์ ์ ๋์ํ n = 32๋ฅผ ์ฌ์ฉ
- ๋ง์ฝ n์ด ์กฐ์ ๋๋ค๋ฉด, ์ด๋ BN์ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ๊ฐ์ฃผ๋์ด์ผ ํ๋ฉฐ ๋ถ์ฐ ํ๋ จ์ด ์๋.
- ๋ํ BN ํต๊ณ๋ ํต์ ์ ์ค์ด๊ธฐ ์ํด์๋ง์ด ์๋๋ผ ์ต์ ํ๋๋ ๊ธฐ๋ณธ ์์ค ํจ์๋ฅผ ๋์ผํ๊ฒ ์ ์งํ๊ธฐ ์ํด์๋ ๋ชจ๋ worker๋ฅผ ๋์์ผ๋ก ๊ณ์ฐํ๋ฉด ์๋จ
- ๋ง์ผ worker๋น ์ํ ํฌ๊ธฐ n์ด ๊ณ ์ ๋์ด ์๊ณ ์ด ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ๊ฐ kn์ผ ๋
- ๋ถ์ฐ ๊ตฌํ ์ค๋ฅ๋ค์ ํ์ดํผํ๋ผ๋ฏธํฐ์ ์ ์๋ฅผ ๋ณ๊ฒฝํ์ฌ ํ๋ จ๋๋ ๋ชจ๋ธ์ ์ค์ฐจ๊ฐ ์์๋ณด๋ค ๋์์ง๊ฒ ํ ์ ์์ผ๋ฉฐ, ์ด๋ฌํ ๋ฌธ์ ๋ค์ ๋ฐ๊ฒฌํ๊ธฐ ์ด๋ ค์ธ ์ ์์ต๋๋ค. ์๋์ ์ฃผ์ฅ๋ค์ ๋ช ํํ์ง๋ง, ๊ธฐ๋ณธ solver๋ฅผ ์ถฉ์คํ ๊ตฌํํ๊ธฐ ์ํด ๋ช ์์ ์ผ๋ก ๊ณ ๋ คํ๋ ๊ฒ์ด ์ค์
Weight decay
- ์์ค ํจ์์ L2 ์ ๊ทํ ํญ์ ๊ธฐ์ธ๊ธฐ์ ๊ฒฐ๊ณผ
- per-sample loss
l(x,w)=ฮป2โฃโฃwโฃโฃ2+ฮต(x,w)l(x, w) = \frac{\lambda }{2}||w||^2 + \varepsilon (x, w)
- ฮป2โฃโฃwโฃโฃ2\frac{\lambda}{2} ||w||^2 : ๊ฐ์ค์น์ ๋ํ ์ํ ๋ ๋ฆฝ์ ์ธ L2 ์ ๊ทํ
- ฮต(x,w)\varepsilon (x, w) : ํฌ๋ก์ค ์ํธ๋กํผ์ ์ํ ์ข ์์ ์ธ ํญ(sample-dependent term)
- SGD update
wt+1=wtโฮทฮปwtโฮท1nโxโBโฝฮต(x,wt)w_{t+1} = w_t - \eta \lambda w_t - \eta\frac{1}{n}\sum_{x \in B} \bigtriangledown \varepsilon (x, w_t)
- ์ค์ ๋ก๋ ์ผ๋ฐ์ ์ผ๋ก ์ญ์ ํ๋ฅผ ํตํด ์ํ ์ข ์์ ์ธ ํญโโฝฮต(x,wt)\sum \bigtriangledown \varepsilon (x, w_t) ๋ง ๊ณ์ฐ๋จ
- ๊ฐ์ค์น ๊ฐ์ ํญฮปwt\lambda w_t ๋ ๋ณ๋๋ก ๊ณ์ฐ๋์ดฮต(x,wt)\varepsilon (x, w_t)์ ๊ธฐ์ฌํ ๊ทธ๋ ๋์ธํธ์ ์ถ๊ฐ๋จ
- ๊ฐ์ค์น ๊ฐ์ ํญ์ด ์๋ ๊ฒฝ์ฐ, ํ์ต๋ฅ ์ ์กฐ์ ํ๋ ๋ค์ํ ๋ฐฉ๋ฒ์ด ์์
- ฮต(x,wt)\varepsilon (x, w_t)ํญ์ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ ํฌํจ
- ๊ทธ๋ฌ๋ ์ด๋ ํญ์ ์ฑ๋ฆฝํ์ง ์์.
Momentum correction
- Momentum SGD
- ๋ฐ๋๋ผ SGD์ ๋ํ ํํ ์ฑํ๋๋ modification์ค ํ๋
- ๋ชจ๋ฉํ
SGD์ ์ฐธ์กฐ ๊ตฌํ
ut+1=mut+1nโฝl(x,wt)โโโโ(9)u_{t+1} = mu_t + \frac{1}{n} \bigtriangledown l (x, w_t)----(9)wt+1=wtโฮทut+1โโโโ(9)w_{t+1} = w_t - \eta u_{t+1}----(9)
- m : momentum decay factor
- u : update tensor
- ์ธ๊ธฐ ์๋ ๋ณํ
- ํ์ต๋ฅ ฮท๋ฅผ ์
๋ฐ์ดํธ ํ
์์ ํก์ํ๋ ๊ฒ
vt+1=mut+ฮท1nโxโBโฝl(x,wt)โโโโ(10)v_{t+1} = mu_t + \eta \frac{1}{n}\sum_{x \in B} \bigtriangledown l (x, w_t)----(10)wt+1=wtโvt+1โโโโ(10)w_{t+1} = w_t - v_{t+1}----(10)
- ๊ณ ์ ๋ฮท\eta์ ๋ํด 2๊ฐ์ง๊ฐ ๋์ผํจ
- u๊ฐ ๊ทธ๋ ์ด๋์ธํธ์๋ง ์์กดํ๊ณ ฮท\eta์๋ ๋ ๋ฆฝ์ ์ธ ๋ฐ๋ฉด, v๋ฮท\eta์ ์ฝํ ์๋ค๋ ์ ์ ์ฃผ๋ชฉํด์ผ ํจ.
- ฮท\eta๊ฐ ๋ณ๊ฒฝ๋ฌ์ ๋ (9)์ ์ฐธ์กฐ๋ณํ๊ณผ ๋์ผ์ฑ์ ์ ์ง๋๊ธฐ ์ํด v์ ์
๋ฐ์ดํธ๋ ๋ค์๊ณผ ๊ฐ์์ ธ์ผ ํจ
vt+1=mฮทt+1ฮทt+ฮทt+11nโโฝl(x,wt)v_{t+1} = m \frac{\eta_{t+1}}{\eta_t} + \eta_{t+1}\frac{1}{n}\sum \bigtriangledown l (x, w_t)
- ฮทt+1ฮทt\frac{\eta_{t+1}}{\eta_{t}} : ๋ชจ๋ฉํ ๋ณด์ (momentum correction)
- ฮทt+1>>ฮทt\eta_{t+1} >> \eta_t ์ธ ๊ฒฝ์ฐ ํ๋ จ์ ์์ ํ์ํค๋ ๋ฐ ์ค์ํจ.
- ํ์ต๋ฅ ฮท๋ฅผ ์
๋ฐ์ดํธ ํ
์์ ํก์ํ๋ ๊ฒ
Gradient aggregation
- k ๊ฐ์ worker๋ง๋ค ๊ฐ๊ฐ ํฌ๊ธฐ๊ฐ n์ธ ๋ฏธ๋๋ฐฐ์น๋ฅผ ๊ฐ์ง๊ณ ์๋ ๊ฒฝ์ฐ, (4)์ ๋ฐ๋ผ ๊ทธ๋ ์ด๋์ธํธ ์ง๊ณ๋ kn๊ฐ์ ์ ์ฒด ์์ ์ธํธ์ ๋ํด ๋ค์ ์์ด ์ํ๋์ด์ผ ํจ.
1knโjโxโBjl(x,wt)\frac{1}{kn}\sum_j\sum_{x \in B_j}l(x, w_t)
- ์์ค ๋ ์ด์ด
- ์ผ๋ฐ์ ์ผ๋ก ์์ฒด์ ๋ก์ปฌ ์ ๋ ฅ์ ๋ํ ํ๊ท ์์ค์ ๊ณ์ฐํ๋ ๋ฐฉ์์ผ๋ก ๊ตฌํ
- ๊ฐ ์์ ์์ ์์คโl(x,wt)/n\sum l(x, w_t)/n์ ๊ณ์ฐํ๋ ๊ฒ๊ณผ ๋์ผํจ
- ์๊ธฐ ๋ด์ฉ์ ๊ณ ๋ คํ๋ฉด ์ฌ๋ฐ๋ฅธ ์ง๊ณ๋ ๋๋ฝ๋ 1/k์์๋ฅผ ๋ณต์ํ๊ธฐ ์ํด k๊ฐ์ ๊ทธ๋ ๋์ธํธ๋ฅผ ํ๊ท ํํด์ผํจ.
- ๊ทธ๋ฌ๋, allreduce [11]์ ๊ฐ์ ํ์ค ํต์ ์์๋ ํ๊ท ์ด ์๋ ํฉ์ ์ํ
- ๋ฐ๋ผ์ 1/k์ค์ผ์ผ๋ง์ ์์ค์ ํก์ํ๋ ๊ฒ์ด ๋ ํจ์จ์ ์ด๋ฉฐ, ์ด ๊ฒฝ์ฐ์๋ ์์ค์ ๋ํ ์ ๋ ฅ์ ๊ทธ๋ ์ด๋์ธํธ๋ง ์ค์ผ์ผ๋งํ๋ฉด ๋๋ฏ๋ก ์ ์ฒด ๊ทธ๋ ์ด๋์ธํธ ๋ฒกํฐ๋ฅผ ์ค์ผ์ผ๋งํ ํ์๊ฐ ์์.Remark 3: Normalize the per-worker loss by total minibatch size kn, not per-worker size n
- ๋ํ โk๋ฅผ ์ทจ์โํ๊ธฐ์ํดฮท^=ฮท\hat{\eta} = \eta ์ค์ ํ๊ณ ์์ค์ 1/n์ผ๋ก ์ ๊ทํํ๋ ๊ฒ์ ์๋ชป๋ ์ ์์
- ์ด๋ ์๋ชป๋ ๊ฐ์ค์น ๊ฐ์ ๋ก ์ด์ด์ง (โ remark 1 ์ฐธ์กฐ)
Data shuffling
- SGD๋ ์ผ๋ฐ์ ์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ๋ณต์์ถ์ถ๋ก ์ํ๋งํ๋ ํ๋ก์ธ์ค๋ก ๋ถ์
- ์ค์ ๋ก ์ผ๋ฐ์ ์ธ SGD ๊ตฌํ์์๋ ๊ฐ SGD ์ํฌํฌ๋ง๋ค ํ๋ จ ์ธํธ๋ฅผ ๋ฌด์์๋ก ์์ด์ฃผ๋ ๊ฒ์ด ํํ๋ฉฐ, ์ด๋ ๋ ๋์ ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์์์.
- ๋ฌด์์ ์
ํ๋ง์ ์ฌ์ฉํ๋ ๊ธฐ์ค์ (K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016)๊ณผ ๊ณต์ ํ ๋น๊ต๋ฅผ ์ ๊ณตํ๊ธฐ ์ํด, k worker์ ์ํด ์ํ๋ ํ ์ํฌํฌ์ ์ํ์ ํ๋ จ ์ธํธ์ ๋จ์ผ ์ผ๊ด๋ ๋ฌด์์ ์๊ธฐ์์ ๊ฐ์ ธ์ค๋๋ก ํจ.
- ์ด๋ฅผ ๋ฌ์ฑํ๊ธฐ ์ํด ๊ฐ ์ํฌํฌ๋ง๋ค k ๋ถ๋ถ์ผ๋ก ๋๋์ด์ง ๋ฌด์์ ์๊ธฐ๋ฅผ ์ฌ์ฉํ๋ฉฐ, ๊ฐ ๋ถ๋ถ์ k worker ์ค ํ๋์ ์ํด ์ฒ๋ฆฌ
- ์ฌ๋ฌ worker์์ ๋ฌด์์ ์ ํ๋ง์ ์ฌ๋ฐ๋ฅด๊ฒ ๊ตฌํํ์ง ์์ผ๋ฉด ๋๋ ทํ ๋ค๋ฅธ ๋์์ ์ ๋ฐ ๊ฐ๋ฅํจโ๊ฒฐ๊ณผ์ ๊ฒฐ๋ก ์ ์ค์ผ
Remark 4: Use a single random shuffling of the training data (per epoch) that is divided amongst all k workers
- ํ๋์ Big Basin ์๋ฒ์์์ 8๊ฐ์ GPU๋ฅผ ๋์ด์ ๊ท๋ชจ๋ฅผ ํ์ฅํ๋ ค๋ฉด , ๊ทธ๋ ์ด๋์ธํธ ์ง๊ณ๋ ๋คํธ์ํฌ ์์ ์ฌ๋ฌ ์๋ฒ์ ๊ฑธ์ณ ์ด๋ค์ ธ์ผ ํจ. ๊ฑฐ์ ์๋ฒฝํ ์ ํ ํ์ฅ์ ํ์ฉํ๊ธฐ ์ํด์๋ ์ง๊ณ๊ฐ ์ญ์ ํ์ ๋ณ๋ ฌ๋ก ์ํ๋์ด์ผ ํฉ๋๋ค. ์ด๋ ์ธต ๊ฐ์ ๊ทธ๋ ์ด๋์ธํธ ๊ฐ์ ๋ฐ์ดํฐ ์์กด์ฑ์ด ์๊ธฐ ๋๋ฌธ์ ๊ฐ๋ฅํฉ๋๋ค. ๋ฐ๋ผ์ ํ ์ธต์ ๊ทธ๋ ์ด๋์ธํธ๊ฐ ๊ณ์ฐ๋๋ฉด ์ฆ์ ํด๋น ์ธต์ ๋ํ ๊ทธ๋ ์ด๋์ธํธ๊ฐ ์์ ์๋ค ๊ฐ์ ์ง๊ณ๋๊ณ , ๋์์ ๋ค์ ์ธต์ ๋ํ ๊ทธ๋ ์ด๋์ธํธ ๊ณ์ฐ์ด ๊ณ์๋ฉ๋๋ค. ๋ค์์์๋ ์ด์ ๋ํ ์์ธํ ๋ด์ฉ์ ์ค๋ช ํ๊ฒ ์ต๋๋ค.
4.1 Gradient Aggregation
- ๊ฐ ๊ทธ๋ ์ด๋์ธํธ์ ๋ํด ์ง๊ณ๋ MPI ์งํฉ ์ฐ์ฐ(MPI Allreduce)[11]์ ์ ์ฌํ allreduce ์์ ์ ์ฌ์ฉํ์ฌ ์ํ๋ฉ๋๋ค. Allreduce๊ฐ ์์๋๊ธฐ ์ ์ ๊ฐ GPU๋ ๋ก์ปฌ๋ก ๊ณ์ฐ๋ ๊ทธ๋ ์ด๋์ธํธ๋ฅผ ๊ฐ๊ณ ์๊ณ , allreduce๊ฐ ์๋ฃ๋๋ฉด ๊ฐ GPU๋ ๋ชจ๋ k๊ฐ์ ๊ทธ๋ ์ด๋์ธํธ์ ํฉ์ ๊ฐ์ต๋๋ค. ๋งค๊ฐ๋ณ์์ ์๊ฐ ์ฆ๊ฐํ๊ณ GPU์ ๊ณ์ฐ ์ฑ๋ฅ์ด ํฅ์๋จ์ ๋ฐ๋ผ ์ง๊ณ ๋น์ฉ์ backprop ๋จ๊ณ์์ ์จ๊ธฐ๊ธฐ๊ฐ ๋ ์ด๋ ค์์ง๋๋ค. ์ด๋ฌํ ํจ๊ณผ๋ฅผ ๊ทน๋ณตํ๊ธฐ ์ํ ํ๋ จ ๊ธฐ์ ์ ์ด ์์ ์ ๋ฒ์๋ฅผ ๋ฒ์ด๋ฉ๋๋ค (์: ์์ํ๋ ๊ทธ๋ ์ด๋์ธํธ(quantized gradient) [18], ๋ธ๋ก-๋ชจ๋ฉํ SGD [6]). ๊ทธ๋ฌ๋ ์ด ์์ ์ ๊ท๋ชจ์์๋ ์ต์ ํ๋ allreduce ๊ตฌํ์ ์ฌ์ฉํ์ฌ ๊ฑฐ์ ์ ํ์ ์ธ SGD ์ค์ผ์ผ๋ง์ ๋ฌ์ฑํ ์ ์์ด, ์ง๋จ ํต์ ์ด ๋ณ๋ชฉ์ด ๋์ง ์์์ต๋๋ค.
- allreduce ๊ตฌํ 3๋จ๊ณ : ์๋ฒ ๋ด๋ถ ๋ฐ ์๋ฒ ๊ฐ ํต์ ์ ์ํด์.
๐ก1. ์๋ฒ ๋ด์ 8๊ฐ GPU์์ ๊ฐ๊ฐ์ ๋ฒํผ๊ฐ ๊ฐ ์๋ฒ์ ๋ํด ํ๋์ ๋จ์ผ ๋ฒํผ๋ก ํฉ์ฐ
- ๊ฒฐ๊ณผ ๋ฒํผ๋ ๋ชจ๋ ์๋ฒ ๊ฐ์ ๊ณต์ ๋์ด ํฉ์ฐ
- ๊ฒฐ๊ณผ๊ฐ ๊ฐ GPU๋ก ๋ธ๋ก๋์บ์คํธ๋จ
- ์๋ฒ ๊ฐ allreduce๋ฅผ ์ํด ๋์ญํญ ์ ํ ์๋๋ฆฌ์ค์ ๋ํ ๋ ๊ฐ์ง ์ต๊ณ ์ ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํ
- ์ฌ๊ท์ ์ธ ๋ฐ๊ฐ ๋ฐ ๋ฐฐ๊ฐ ์๊ณ ๋ฆฌ์ฆ
- 2 log2(p) ํต์ ๋จ๊ณ๋ก ์ด๋ฃจ์ด์ง
- reduce-scatter ์งํฉ์ ์ด์ด allgather๋ก ์ด๋ฃจ์ด์ ธ ์์
- process
- ์๋ฒ๋ ์์ผ๋ก ํต์ ํ๋ฉฐ (๋ญํฌ 0์ 1๊ณผ, 2๋ 3๊ณผ ๋ฑ๋ฑ), ์ ๋ ฅ ๋ฒํผ์ ๋ค๋ฅธ ๋ฐ์ชฝ์ ๋ํด ๋ณด๋ด๊ณ ๋ฐ์ (๋ญํฌ 0์ ๋ฒํผ์ ๋ ๋ฒ์งธ ์ ๋ฐ์ 1์๊ฒ ๋ณด๋ด๊ณ 1๋ก๋ถํฐ ๋ฒํผ์ ์ฒซ ๋ฒ์งธ ์ ๋ฐ์ ๋ฐ์)
- ๋ค์ ๋จ๊ณ๋ก ์งํํ๊ธฐ ์ ์ ์์ ๋ ๋ฐ์ดํฐ์ ๋ํ ์ถ์๊ฐ ์ํ๋๊ณ , ๋ค์ ๋จ๊ณ์์๋ ๋ชฉ์ ์ง ๋ญํฌ๊น์ง์ ๊ฑฐ๋ฆฌ๊ฐ ๋ ๋ฐฐ๋ก ๋์ด๋๋ฉด์ ๋ณด๋ด๊ณ ๋ฐ์ ๋ฐ์ดํฐ๊ฐ ์ ๋ฐ์ผ๋ก ์ค์ด๋ค์
- reduce-scatter ๋จ๊ณ๊ฐ ์๋ฃ๋๋ฉด ๊ฐ ์๋ฒ์๋ ์ต์ข ์ถ์๋ ๋ฒกํฐ์ ์ผ๋ถ๊ฐ ์์.
- allgather ๋จ๊ณ
- reduce-scatter์์์ ํต์ ํจํด์ ์ญ์ผ๋ก ์ถ์ ํ์ฌ ์ต์ข ์ถ์๋ ๋ฒกํฐ์ ์ผ๋ถ๋ฅผ ๊ฐ๋จํ ์ฐ๊ฒฐ
- ๊ฐ ์๋ฒ์์ reduce-scatter์์ ๋ณด๋ด๊ณ ์๋ ๋ฒํผ์ ์ผ๋ถ๊ฐ allgather์์ ์์ ๋๊ณ , ๋ฐ๋ ๋ถ๋ถ์ ์ด์ ๋ณด๋ด์ง
- ์๋ฒ์ ์๊ฐ 2์ ๊ฑฐ๋ญ์ ๊ณฑ์ด ์๋ ๊ฒฝ์ฐ์ ๋์ํ๊ธฐ ์ํด ์ด์ง ๋ธ๋ก ์๊ณ ๋ฆฌ์ฆ [30]์ ์ฌ์ฉ
- ์ด์ง ๋ธ๋ก ์๊ณ ๋ฆฌ์ฆ : ์๋ฒ๊ฐ 2์ ๊ฑฐ๋ญ์ ๊ณฑ ๋ธ๋ก์ผ๋ก ๋ถํ ๋๊ณ ๋ ๊ฐ์ ์ถ๊ฐ ํต์ ๋จ๊ณ๊ฐ ์ฌ์ฉ๋๋ ๋ฐ๊ฐ/๋ฐฐ๊ฐ ์๊ณ ๋ฆฌ์ฆ์ ์ผ๋ฐํ๋ ๋ฒ์
- ์๋ ๋ธ๋ก ๋ด๋ถ reduce-scatter ํ์ ๋ธ๋ก ๋ด๋ถ allgather ์ ์ ๊ฐ๊ฐ ํ ๋ฒ ์ฌ์ฉ๋จ
- ๊ฑฐ๋ญ์ ๊ณฑ์ด ์๋ ๊ฒฝ์ฐ ์ผ๋ถ ๋ถํ ๋ถ๊ท ํ์ด ๊ฑฐ๋ญ์ ๊ณฑ๊ณผ ๋น๊ตํ์ฌ ๋ฐ์ํ์ง๋ง, ํ ๋ ผ๋ฌธ์์๋ ์ฑ๋ฅ์ ํ๋ฅผ ๊ด์ฐฐํ์ง ๋ชปํจ.
- process
- ๋ฒํท ์๊ณ ๋ฆฌ์ฆ (๋ง ์๊ณ ๋ฆฌ์ฆ์ด๋ผ๊ณ ๋ ํจ)
- 2(pโ1) ๋จ๊ณ๋ก ์ด๋ฃจ์ด์ง.
- ๋ฐ๊ฐ/๋ฐฐ๊ฐ ์๊ณ ๋ฆฌ์ฆ์ด ๋๊ฐ latency-limited ์๋๋ฆฌ์ค์์ ๋ ๋น ๋ฅด๊ฒ ๋์(์ฆ, ์์ ๋ฒํผ ํฌ๊ธฐ ๋ฐ/๋๋ ํฐ ์๋ฒ ์์ ๊ฒฝ์ฐ ,์ฝ 3).
- ์ฌ๊ท์ ์ธ ๋ฐ๊ฐ ๋ฐ ๋ฐฐ๊ฐ ์๊ณ ๋ฆฌ์ฆ
4.2 software
- ํต์ ์์ง์ ์ํ allreduce ์๊ณ ๋ฆฌ์ฆ์ Gloo ๊นํ๋ธ์ ์์.
- ๋ณ๋ ฌ๋ก ์ฌ๋ฌ allreduce ์ธ์คํด์ค๋ฅผ ์คํํ๊ธฐ ์ํด ์ถ๊ฐ ๋๊ธฐํ๊ฐ ํ์ํ์ง ์๋ ์ฌ๋ฌ ํต์ context๋ฅผ ์ง์
- ๋ก์ปฌ ๋ฆฌ๋์ ๋ฐ ๋ธ๋ก๋์บ์คํธ (๋จ๊ณ (1) ๋ฐ (3)๋ก ์ค๋ช ๋จ)์ ๊ฐ๋ฅํ ๊ฒฝ์ฐ ์๋ฒ ๊ฐ allreduce์ ํ์ดํ๋ผ์ธํ๋จ
- Caffe2
- ํ๋ จ ๋ฐ๋ณต์ ๋ํ๋ด๋ ์ปดํจํธ ๊ทธ๋ํ์ ๋ฉํฐ์ค๋ ๋ ์คํ์ ์ง์
- ์๋ธ๊ทธ๋ํ ๊ฐ์ ๋ฐ์ดํฐ ์์กด์ฑ์ด ์๋ ๊ฒฝ์ฐ ์ฌ๋ฌ ์ค๋ ๋๊ฐ ๊ทธ ์๋ธ๊ทธ๋ํ๋ฅผ ๋ณ๋ ฌ๋ก ์คํ๊ฐ๋ฅ
- ์ด๋ฅผ backprop์ ์ ์ฉํ๋ฉด ๋ก์ปฌ ๊ทธ๋ ์ด๋์ธํธ๊ฐ ์์ฐจ์ ์ผ๋ก ๊ณ์ฐ๋ ์ ์๊ณ , allreduce๋ ๊ฐ์ค์น ์
๋ฐ์ดํธ์ ๊ด๋ จ์ด ์์
- ์ด๋ backprop ์ค์ ์คํ ๊ฐ๋ฅํ ์๋ธ๊ทธ๋ํ ์งํฉ์ด ์คํ ๊ฐ๋ฅํ ์๋ธ๊ทธ๋ํ๋ฅผ ์คํํ๋ ์๋๋ณด๋ค ๋ ๋นจ๋ฆฌ ์ฆ๊ฐํ ์ ์์์ ์๋ฏธ
- allreduce๋ฅผ ํฌํจํ๋ ์๋ธ๊ทธ๋ํ์ ๊ฒฝ์ฐ ๋ชจ๋ ์๋ฒ๊ฐ ์คํ ๊ฐ๋ฅํ ์๋ธ๊ทธ๋ํ ์งํฉ์์ ๋์ผํ ์๋ธ๊ทธ๋ํ๋ฅผ ์คํํ๋๋ก ์ ํํด์ผ ํจ
- ๊ทธ๋ ์ง ์์ผ๋ฉด ์๋ฒ๊ฐ ์๋ก ๊ต์ฐจํ์ง ์๋ ์๋ธ๊ทธ๋ํ ์งํฉ์ ์คํํ๋ ค๊ณ ํ ๋ ๋ถ์ฐ ๋ฐ๋๋ฝ์ด ๋ฐ์ํ ์ํ์ด ์์.
- allreduce๊ฐ ์งํฉ ์ฐ์ฐ์ด๊ธฐ ๋๋ฌธ์ ์๋ฒ๋ ํ์์์ ๋ ๊ฑฐ์
- ์ฌ๋ฐ๋ฅธ ์คํ์ ๋ณด์ฅํ๊ธฐ ์ํด ์ด๋ฌํ ์๋ธ๊ทธ๋ํ์ ๋ํ ๋ถ๋ถ์ ์ธ ์์๋ฅผ ๋ถ์ฌํด์ผํจ.
- ์ํ์ ์ธ ์ ์ด ์ ๋ ฅ์ ์ฌ์ฉํ์ฌ ๊ตฌํ๋จ
- n๋ฒ์งธ allreduce์ ์๋ฃ๊ฐ (n + c)๋ฒ์งธ allreduce์ ์คํ์ ๋ธ๋ก ํด์ ํ๊ฒ ๋จ
- ์ฌ๊ธฐ์ c๋ ์ต๋ ๋์ allreduce ์คํ ํ์
- ์ด ์ซ์๋ ์ ์ฒด ์ปดํจํธ ๊ทธ๋ํ๋ฅผ ์คํํ๋ ๋ฐ ์ฌ์ฉ๋๋ ์ค๋ ๋ ์๋ณด๋ค ๋ฎ๊ฒ ์ ํํด์ผ ํจ.