๐ก
imageNet์ dataset์์ ํฐ ๋ฏธ๋๋ฐฐ์น๋ ์ต์ ํ๋ฅผ ์ํค๋๋ฐ ์ด๋ ต๋ค๋๊ฒ์ ๋ณด์ฌ์ค(ํ์ง๋ง ์ด ๋ฌธ์ ๊ฐ ํด๊ฒฐ๋๋ฉด ํ๋ จ๋ชจ๋ธ์ ์ข์ ์ผ๋ฐํ๋ฅผ ๊ฐ์ง). ์ด ๋ฌธ์ ๋ฅผ ์ํด ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ์ ํจ์๋ก ํ์ต์๋๋ฅผ ์กฐ์ ํ๊ธฐ ์ํ ์ด๋งค๊ฐ๋ณ์ ์๋ ์ ํ ํ์ฅ ๊ท์น์ ์ฑํํจ. ๋ํ ํ๋ จ ์ด๊ธฐ์ ์ต์ ํ ๋ฌธ์ ๋ฅผ ๊ทน๋ณตํ๋ ์๋ก์ด ์๋ฐ์
๋ฐฉ์ ๊ฐ๋ฐ.
This paper’s goal
- ๋ถ์ฐ์ ๋น๋๊ธฐ์ SGD(stochastic gradient descent)๊ฐ ๋๊ท๋ชจ ํธ๋ ์ด๋์ ์ ํฉํ๋ค๋ ๊ฒ์ ์ฆ๋ช ํ๊ณ ์ค์ฉ์ ์ธ ๊ฐ์ด๋ ์ ๋ฌํ๊ธฐ
ํฐ ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ๋ฅผ ๋ค๋ฃจ๊ธฐ ์ํด ํ์ํ ๊ฒ
- hyper-parameter-free linear scaling rule๋ฅผ ์ด์ฉํ์ฌ learning rate ์กฐ์ ํ๊ธฐ
- ์ด ๊ฐ์ด๋๋ผ์ธ์ earlier worker์์ ์ค๋ฆฝ๋จ
- ๊ฒฝํ์ ํ๊ณ๋ ์ ์ดํด๋์ง ์์ผ๋ฉฐ, ๋น๊ณต์์ ์ด๊ธฐ ๋๋ฌธ์ research community์ ์ ์๋ ค์ง์ง ์์
- ์ด ๋ฒ์น์ ์ฑ๊ณต์ ์ผ๋ก ์ ์ฉํ๊ธฐ ์ํด ์๋ก์ด ์ค๋น์ ๋ต ์ ์
- ์ค๋น์ ๋ต = ์ด๊ธฐ ์ต์ ํ๋ฌธ์ ๋ฅผ ๊ทน๋ณตํ๊ธฐ ์ํด ๋ฎ์ learning rate ์ฌ์ฉํ๊ธฐ
- ์์ค L(w)์
L(w)=1โฃXโฃ∑x∈Xl(x,w)L(w) = \frac{1}{|X|}\sum_{x\in X}l(x, w)
- w : ๋คํธ์ํฌ์ ๊ฐ์ค์น(weight of a network)
- X : ๋ ์ด๋ธ์ด ์ง์ ๋ training set
- l(x, w) : ์ํx∈Xx \in X๊ณผ ๋ผ๋ฒจ y์์ ๊ณ์ฐ๋ ์์ค
- l : ๋ถ๋ฅ์์ค(๊ต์ฐจ์ํธ๋กํผ)์ w์๋ํ ์ ๊ท์์ค์ ํฉ
- Minibatch stochastic gradient decent๋ ๋ฏธ๋๋ฐฐ์น์์ ์๋ํจ. ๋ณดํต ๊ฐ๋จํ๊ฒ SGD๋ผ๊ณ ํ๋ฉฐ, ๋ค์ ์
๋ฐ์ดํธ๋ฅผ ์ํํจ
wt+1=wt−η1n∑x∈Bโฝ(x,wt)w_{t+1} = w_t - \eta\frac{1}{n}\sum_{x \in B} \bigtriangledown (x, w_t)
- BB : X์์ ์ํ๋ง๋ ๋ฏธ๋๋ฐฐ์น์
- n=โฃBโฃn= |B | : ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ
- η\eta : learning rate
- t : iteration index
- ์ด ๋ ผ๋ฌธ์์๋ ์ค์ ๋ก momentum SGD๋ฅผ ์ฌ์ฉํจ.
2.1 Learning rates for Large Minibatches
- ๋ชฉํ
- ์์ ๋ฏธ๋๋ฐฐ์น ๋์ ํฐ ๋ฏธ๋๋ฐฐ์น๋ฅผ ์ฐ๋ฉด์ training๊ณผ ์ผ๋ฐํ ์ ํ๋ ์ ์งํ๊ธฐ⇒ worker๋ณ ์์ ๋์ ์ค์ด๊ฑฐ๋ ๋ชจ๋ธ ์ ํ๋๋ฅผ ์ ํ์ํค์ง ์๊ณ ๊ฐ๋จํ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ฒ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ๋ฌ worker๋ก ํ์ฅํ ์ ์๊ธฐ ๋๋ฌธ์ ๋ถ์ฐํ์ต์์ ํนํ ์ค์!
- (worker์ GPU๋ฅผ ๊ฐ์ ์๋ฏธ๋ก ์ฌ์ฉ)
- learning rate scaling rule์ด ๊ด๋ฒ์ํ ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ์ ๋๋ผ์ธ์ ๋๋ก ํจ๊ณผ์ ์linear scaling rule : ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ์ k๋ฅผ ๊ณฑํ๋ฉด, ํ์ต๋ฅ ์ k๋ฅผ ๊ณฑํ๊ธฐ
- ๋ค๋ฅธ hyper-parameter(weight decay)๋ ๋ฐ๊พธ์ง ์๊ธฐ
- linear scaling rule์ ์์ ๋ฏธ๋๋ฐฐ์น์ ํฐ ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ ์ ํ๋๋ฅผ ์ผ์น์ํค๋๋ฐ ๋์์ ์ค. ๋, training curve๋ ๋งค์น ์ํด⇒ ์คํ์ ์ ๋น ๋ฅด๊ฒ ๋น๊ตํ๊ณ ๋๋ฒ๊น ํ ์ ์๊ฒ๋จ
- interpretation
- leaner scaling rule๊ณผ ์ ํจ๊ณผ์ ์ธ์ง ์ค๋ช
ํ๊ฒ ์→ iteration์ด t, weight๊ฐ w์ธ ๋คํธ์ํฌ,0≤j<k0 \leq j <k์ ๋ํด k๊ฐ์ ๋ฏธ๋๋ฐฐ์นBjB_j์ ์ํ์ค๋ฅผ ๊ณ ๋ คํด์ผํจ
- k SGD iteration(์์ ๋ฏธ๋๋ฐฐ์นBjB_j, learning rate๊ฐη\eta)์ single iteration(์ฌ์ด์ฆ๊ฐ kn ์ธ ํฐ ๋ฏธ๋๋งค์น∪jBj\cup_j B_j์ ํ์ต๋ฅ η^\hat{\eta})์ ์คํ ํจ๊ณผ๋ฅผ ๋น๊ต
- learning rate๊ฐη\eta, ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ๊ฐ n์ธ SGD์ k iteration
wt+k=wt−η1n∑j<k∑x∈Bjโฝl(x,wt+j)−−−−(3)w_{t+k} = w_t -\eta \frac{1}{n} \sum_{j<k}\sum_{x \in B_j} \bigtriangledown l (x, w_{t+j})---- (3)
- learning rate๊ฐη^\hat{\eta}, ์ฌ์ด์ฆ๊ฐ kn ํฐ ๋ฏธ๋๋ฐฐ์น∪jBj\cup _j B_j๋ฅผ ์ฌ์ฉํ ๋จ์ผ ๋จ๊ณ๋ฅผ ์ฌ์ฉํ๋ฉด
w^t+1=wt−η^1kn∑j<k∑x∈Bjโฝl(x,wt)−−−−(4)\hat{w}_{t+1} = w_t - \hat{\eta}\frac{1}{kn}\sum_{j<k}\sum_{x \in B_j}\bigtriangledown l (x, w_t)---- (4) - learning rate๊ฐη\eta, ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ๊ฐ n์ธ SGD์ k iteration
- ์์ํ๊ฒ์ฒ๋ผ ์ ๋ฐ์ดํธ ๊ฐ์ ์๋ก ๋ค๋ฅด๋ฉฐ,w^t+1=wt+k\hat{w}_{t+1} = w_{t+k}์ผ ํ๋ฅ ์ ์์.
- ํ์ง๋ง, j<k์ ๋ํดโฝl(x,wt)≈โฝl(x,wt+j)\bigtriangledown l (x, w_t) \approx \bigtriangledown l(x, w_{t+j})๋ ๊ฐ์ ํ ์ ์๋ ๊ฒฝ์ฐη^=kη\hat{\eta} = k\eta๋ก ์ค์ ํ ๋w^t+1≈wt+k\hat{w}_{t+1} \approx w_{t+k}๊ฐ ์์ฑ๋๊ณ , ์์ ๋ฏธ๋๋ฐฐ์น SGD์ ํฐ ๋ฏธ๋๋ฐฐ์น SGD๋ ์๋ก ์ ์ฌํจ.
- ๋ํ, ๊ฐ๋ ฅํ ๊ฐ์ ์์๋ ๋ถ๊ตฌํ๊ณ ์ด๊ฒ์ด ์ฌ์ค์ด๋ผ๋ฉด,η^=kη\hat{\eta} = k\eta๋ฅผ ์ค์ ํ ๊ฒฝ์ฐ์๋ง ๋ ์ ๋ฐ์ดํธ๊ฐ ์ ์ฌํ๋ค๊ณ ๊ฐ์กฐํจ.
- k SGD iteration(์์ ๋ฏธ๋๋ฐฐ์นBjB_j, learning rate๊ฐη\eta)์ single iteration(์ฌ์ด์ฆ๊ฐ kn ์ธ ํฐ ๋ฏธ๋๋งค์น∪jBj\cup_j B_j์ ํ์ต๋ฅ η^\hat{\eta})์ ์คํ ํจ๊ณผ๋ฅผ ๋น๊ต
- ์ ํด์์ linear scaling rule์ด ์ ์ฉ๋๊ธฐ๋ฅผ ๋ฐ๋ผ๋ ํ๊ฐ์ง ๊ฒฝ์ฐ์ ๋ํ ์ง๊ด์ ์ ๊ณตํจ.η^=kη\hat{\eta} = k \eta(๋ฐ ์ค๋น)์ธ ์คํ์์ ์์ ๋ฏธ๋๋ฐฐ์น SGD์ ํฐ ๋ฏธ๋๋ฐฐ์น SGD๋ ๋ชจ๋ธ์์์ ๊ฐ์ ๋ง์ง๋ง ์ ํ๋๋ฟ๋ง ์๋๋ผ training curve๋ํ ๊ฝค ๋งค์น๋จ. ์คํ๊ฒฐ๊ณผ๋ ์ ๊ทผ์ฌ์น๊ฐ ๋๊ท๋ชจ ์ค์ ๋ฐ์ดํฐ์์ ์ ํจํ ์ ์๋ค๊ณ ์ ์ํ๋ค.
- ๊ทธ๋ฌ๋ ์กฐ๊ฑดโฝl(s,wt)≈โฝl(x,wt+j)\bigtriangledown l (s, w_t) \approx \bigtriangledown l (x, w_{t+j})๊ฐ ์ ์ง๋์ง ์๋ ๋๊ฐ์ง ๊ฒฝ์ฐ๊ฐ ์์
- ๋คํธ์ํฌ๊ฐ ๋น ๋ฅด๊ฒ ๋ฐ๋๋์ ์ด๊ธฐ ํ๋ จ(initial training)⇒ 2.2์ ์ค๋น๋จ๊ณ๋ฐฉ๋ฒ์ ์ฌ์ฉํด ํด๊ฒฐ
- ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ๊ฐ ๋ฌดํ์ ํ์ฅ๋ ์ ์์.
- ๊ฒฐ๊ณผ๋ ๋ค์ํ ๋ฒ์์ ํฌ๊ธฐ์์ ์์ ์ ์ด์ง๋ง, ํน์ ์ง์ ์ ๋์ผ๋ฉด ์ ํ๋๊ฐ ๋น ๋ฅด๊ฒ ๋จ์ด์ง
- ์ด ์ง์ ์ imageNet ์คํ์์ 8k๋ณด๋ค ํผ
- leaner scaling rule๊ณผ ์ ํจ๊ณผ์ ์ธ์ง ์ค๋ช
ํ๊ฒ ์→ iteration์ด t, weight๊ฐ w์ธ ๋คํธ์ํฌ,0≤j<k0 \leq j <k์ ๋ํด k๊ฐ์ ๋ฏธ๋๋ฐฐ์นBjB_j์ ์ํ์ค๋ฅผ ๊ณ ๋ คํด์ผํจ
- Discussion
- ์๊ธฐ linear scaling rule์ Krizhevsky์ ์ํด ์ฑํ๋จ. ํ์ง๋ง Krizhevsky๋ ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ๊ฐ 128์์ 1024๋ก ์ฆ๊ฐํ ๋ ์๋ฌ๊ฐ 1%๊ฐ ์ฆ๊ฐํ๋ค๊ณ ๋ณด๊ณ ํจ. ๋ฐ๋ฉด์ ์ฐ๋ฆฌ๋ ํจ์ฌ ๋ ๊ด๋ฒ์ํ ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ ๋ฒ์์์ ์ด๋ป๊ฒ ์ ํ๋๋ฅผ ์ ์งํ๋์ง ๋ณด์ฌ์ค.
- Chen์ ์๋ง์ ๋ถ์ฐ SGD ๋ณํ์ ๋น๊ต๋ฅผ ์ ์ํจ. ๊ทธ๋ค์ linear scaling rule์ ์ฌ์ฉํ์ง๋ง, ๋ฏธ๋๋ฐฐ์น ๊ธฐ์ค์ ์ ์ค์ ํ์ง ์์.
- Li๋ ์๋ ดํ ์ ํ๋ ์์ค ์์ด ๋ฏธ๋๋ฐฐ์น๊ฐ 5210๊น์ง์ธ ๋ถ์ฐ imageNet ํ๋ จ์ ๋ณด์ฌ์ค. ํ์ง๋ง ์ฐ๋ฆฌ๊ฐ ํต์ฌ์ ์ผ๋ก ๊ธฐ์ ๋ฏธ๋๋ฐฐ์น ์ฌ์ด์ฆ๋ก ํ์ต๋ฅ ์กฐ์ ์ ์ํด hyper-parameter search rule์ ๋ณด์ฌ์ฃผ์ง ๋ชปํจ.
- ์ต๊ทผ ์ฐ๊ตฌ์์, Bottou ๋ฑ [4] (§4.2)์ ๋ฏธ๋๋ฐฐ์น์ ์ด๋ก ์ ์ฅ๋จ์ ์ ๊ฒํ ํ๊ณ ์ ํ ์ค์ผ์ผ๋ง ๊ท์น์ ๋ฐ๋ผ solver๊ฐ ๋ณธ ์์ ์์ ํจ์๋ก ๋์ผํ ํ๋ จ ๊ณก์ ์ ๋ฐ๋ฅด๋๊ฒ์ ๋ณด์ฌ์ค. ํ์ต๋ฅ ์ ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ์ ๋ ๋ฆฝ์ ์ธ ์ต๋ ์๋๋ฅผ ์ด๊ณผํด์๋ ์ ๋๋ฉฐ(๋ฐ๋ผ์ ์์ ์ด ์ ๋นํ๋จ), ์ ๋ก ์๋ ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ๋ก ์ด๋ฌํ ์ด๋ก ์ ์คํ์ ์ผ๋ก ๊ฒ์ฆํจ.
2.2 warmup
- ๋
ผ์ํ ๋๋ก, ๋๊ท๋ชจ ๋ฏธ๋๋ฐฐ์น(์: 8k)์ ๊ฒฝ์ฐ, ์ ๊ฒฝ๋ง์ด ๋น ๋ฅด๊ฒ ๋ณํ ๋ ์ ํ ์ค์ผ์ผ๋ง ๊ท์น์ด ๋ถ๊ดด๋จ. ์ด๋ ํ๋ จ ์ด๊ธฐ์ ํํ ๋ฐ์ํจ→ ์ด ๋ฌธ์ ๋ฅผ ์ ์ ํ ์ค๊ณ๋ ์์
[16]์ ์ํด ์ํ์ํฌ ์ ์์.
- ํ๋ จ ์์ ์์ ๋ ๊ณต๊ฒฉ์ ์ธ ํ์ต๋ฅ ์ ์ฌ์ฉํ๋ ์ ๋ต
- Constant warmup
- [16]์์ ์ ์๋ ์์ ์ ๋ต์ ํ๋ จ์ ์ฒ์ ๋ช ์ํฌํฌ ๋์ ๋ฎ์ ๊ณ ์ ํ์ต๋ฅ ์ ์ฌ์ฉ
- ์ฌ์ ํ๋ จ๋ ๋ ์ด์ด๋ฅผ ์๋ก ์ด๊ธฐํ๋ ๋ ์ด์ด์ ํจ๊ป ์ธ๋ฐํ๊ฒ ์กฐ์ ํ๋ ๊ฐ์ฒด ๊ฒ์ถ ๋ฐ ์ธ๊ทธ๋ฉํ ์ด์ ๋ฐฉ๋ฒ์ Constant warmup์ด ์ ์ฉ
- ๋๊ท๋ชจ ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ kn์ผ๋ก ํ ImageNet ์คํ์์, ์ฒ์ 5 epochs๋์ ํ์ต๋ฅ η๋ก ํ๋ จํจ. ๊ทธ ํ ๋ชฉํ ํ์ต๋ฅ η^=kη\hat{\eta} = k\eta๋ก ๋์๊ฐ๋ ค๊ณ ํ์. ๊ทธ๋ฌ๋ ํฐ k๊ฐ ์ฃผ์ด์ง ๊ฒฝ์ฐ, ์ด Constant warmup๋ง์ผ๋ก๋ ์ต์ ํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ์ ์ถฉ๋ถํ์ง ์๋ค๋ ๊ฒ์ ๋ฐ๊ฒฌํ์์ผ๋ฉฐ, ๋ฎ์ ํ์ต๋ฅ ์์ ๋จ๊ณ์์์ ์ ํ์ ํ๋ จ ์ค๋ฅ๋ฅผ ๊ธ์ฆ์ํฌ ์ ์์⇒ ์ด๋ก ์ธํด ์ฐ๋ฆฌ๋ ๋ค์๊ณผ ๊ฐ์ ์ ์ง์ ์์ (gradual warmup)์ ์ ์
- gradual warmup
- ํ์ต๋ฅ ์ ์์ ๊ฐ์์๋ถํฐ ํฐ ๊ฐ์ผ๋ก ์ ์ง์ ์ผ๋ก ์ฆ๊ฐ์ํค๋ ๋์์ ์ธ warmup
- ํ์ต๋ฅ ์ ๊ฐ์์ค๋ฌ์ด ์ฆ๊ฐ๋ฅผ ํผํ๋ฉฐ, ํ๋ จ ์ด๊ธฐ์ ๊ฑด๊ฐํ ์๋ ด(healthy convergence)์ ํ์ฉํจ.
- ์ค์ ๋ก ๋๊ท๋ชจ ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ kn์ผ๋ก ์์ํ์ฌ ํ์ต๋ฅ ์η\eta์์ ์์ํ์ฌ ๊ฐ ๋ฐ๋ณต์์ ์์๋๋งํผ ์ฆ๊ฐ์์ผ 5 epochs ํ์η^=kη\hat{\eta} = k\eta์ ๋๋ฌํ๋๋ก ํจํฉ๋๋ค (์ ํํ warmup ๊ธฐ๊ฐ์ ๋ํ ๊ฒฐ๊ณผ๋ ๊ฒฌ๊ณ ํจ). warmup ํ์๋ ์๋์ ํ์ต๋ฅ ๋ก ๋์๊ฐ.
2.3 Batch Normalization with large minibatches
- Batch Normalization : ๋ฏธ๋๋ฐฐ์น ์ฐจ์์ ๋ฐ๋ผ ํต๊ณ๋์ ๊ณ์ฐ
- ๊ฐ ์ํ์ ์์ค์ ๋ ๋ฆฝ์ฑ์ ๊นจ๋จ๋ฆผ
- ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ์ ๋ณํ๋ ์ต์ ํ๋๋ ์์ค ํจ์์ ๊ธฐ๋ณธ ์ ์๋ฅผ ๋ณ๊ฒฝ
- ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ๋ฅผ ๋ณ๊ฒฝํ ๋ ์์ค ํจ์๋ฅผ ๋ณด์กดํ๊ธฐ ์ํ shortcut
- ํต์ ์ค๋ฒํค๋๋ฅผ ํผํ๊ธฐ ์ํ ์ค์ฉ์ ์ธ ๊ณ ๋ ค
- ์์ค ํจ์๋ฅผ ๋ณด์กดํ๋ ๋ฐ ํ์์
- BN์ด ์ํ๋๊ณ ํ์ฑํ๊ฐ ์ํ ๊ฐ์ ๊ณ์ฐ๋ ๋๋ ์ด๋ฌํ ๊ฐ์ ์ด ์ฑ๋ฆฝํ์ง ์์
- ๊ฐ์ : per-sample loss์ธl(x,w)l(x, w)์ด ๋ค๋ฅธ ๋ชจ๋ ์ํ๊ณผ ๋ ๋ฆฝ์ ์
- ํฌ๊ธฐ๊ฐ n์ธ ๋จ์ผ ๋ฏธ๋๋ฐฐ์น B์ ์์ค์L(B,w)=1n∑x∈BlB(x,w)L(B, w) = \frac{1}{n}\sum_{x \in B}l_B(x, w)๋ก ๋ํ
- BN์ด ์ ์ฉ๋ ๊ฒฝ์ฐ, ํ๋ จ ์ธํธ๋ ์๋ ํ๋ จ ์ธํธ X์์ ์ถ์ถ๋ ํฌ๊ธฐ๊ฐ n์ธ ๋ชจ๋ ๊ตฌ๋ณ๋๋ ๋ถ๋ถ์งํฉ์ ํฌํจ ⇒XnX^n
- training loss L(w)
L(w)=1โฃXnโฃ∑B∈XnL(B,w)L(w) = \frac{1}{|X^n|}\sum_{B \in X^n}L(B, w) \\ \\- ๋ง์ฝ B๋ฅผXnX^n์ '๋จ์ผ ์ํ(single sample)'๋ก ๋ณด๋ฉด, ๊ฐ ๋จ์ผ ์ํ B์ ์์ค์ด ๋ ๋ฆฝ์ ์ผ๋ก ๊ณ์ฐ๋จ⇒ ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ n์ด BN ํต๊ณ๋ฅผ ๊ณ์ฐํ๋ ๋ฐ ์ค์ํ ๊ตฌ์ฑ ์์์์ ์ ์ํด์ผ ํจ
- ๊ฐ worker์ ๋ฏธ๋๋ฐฐ์น ์ํ ํฌ๊ธฐ n์ด ๋ณ๊ฒฝ๋๋ฉด ์ต์ ํ๋๋ ๊ธฐ๋ณธ ์์ค ํจ์ L์ด ๋ณ๊ฒฝ๋๊ฒ ๋จ
- BN์ด ๋ค๋ฅธ n์ผ๋ก ๊ณ์ฐํ ํ๊ท /๋ถ์ฐ ํต๊ณ๋ ์๋ก ๋ค๋ฅธ ์์ค์ ๋ฌด์์ ๋ณ๋์ ๋ํ๋
- ๋ถ์ฐ๋(๊ทธ๋ฆฌ๊ณ ๋ฉํฐ-GPU) ํ๋ จ
- ๋ง์ผ worker๋น ์ํ ํฌ๊ธฐ n์ด ๊ณ ์ ๋์ด ์๊ณ ์ด ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ๊ฐ kn์ผ ๋
- ๊ฐ ์ํBjB_j๊ฐXnX^n์์ ๋ ๋ฆฝ์ ์ผ๋ก ์ ํ๋ k๊ฐ์ ์ํ ๋ฏธ๋๋ฐฐ์น๋ก ๋ณผ ์ ์์.
- ์ด ๊ด์ ์์ BN ์ค์ ์์, k๊ฐ์ ๋ฏธ๋๋ฐฐ์นBjB_j๋ฅผ ๋ณธ ํ, (3)๊ณผ (4)๋ ๋ค์๊ณผ ๊ฐ์ด ๊ณ์ฐ๋จ
wt+k=wt−η∑j<kโฝL(Bj,wt+j)w_{t+k} = w_t - \eta \sum_{j<k} \bigtriangledown L (B_j, w_{t+j})wt+k^=wt−η^1k∑j<kโฝL(Bj,wt+j)\hat{w_{t+k}} = w_t - \hat{\eta} \frac{1}{k} \sum_{j<k} \bigtriangledown L (B_j, w_{t+j})
- ๋ณธ ์ฐ๊ตฌ์์๋η^=kη\hat{\eta} = k\eta๋ก ์ค์ ํ๊ณ worker ์ k๋ฅผ ๋ณ๊ฒฝํ ๋ worker๋น ์ํ ํฌ๊ธฐ n์ ์ผ์ ํ๊ฒ ์ ์งํจ. ๊ทธ๋ฆฌ๊ณ ๋ค์ํ ๋ฐ์ดํฐ์ ๊ณผ ๋คํธ์ํฌ์์ ์ ๋์ํ n = 32๋ฅผ ์ฌ์ฉ
- ๋ง์ฝ n์ด ์กฐ์ ๋๋ค๋ฉด, ์ด๋ BN์ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก ๊ฐ์ฃผ๋์ด์ผ ํ๋ฉฐ ๋ถ์ฐ ํ๋ จ์ด ์๋.
- ๋ํ BN ํต๊ณ๋ ํต์ ์ ์ค์ด๊ธฐ ์ํด์๋ง์ด ์๋๋ผ ์ต์ ํ๋๋ ๊ธฐ๋ณธ ์์ค ํจ์๋ฅผ ๋์ผํ๊ฒ ์ ์งํ๊ธฐ ์ํด์๋ ๋ชจ๋ worker๋ฅผ ๋์์ผ๋ก ๊ณ์ฐํ๋ฉด ์๋จ
- ๋ง์ผ worker๋น ์ํ ํฌ๊ธฐ n์ด ๊ณ ์ ๋์ด ์๊ณ ์ด ๋ฏธ๋๋ฐฐ์น ํฌ๊ธฐ๊ฐ kn์ผ ๋
- ๋ถ์ฐ ๊ตฌํ ์ค๋ฅ๋ค์ ํ์ดํผํ๋ผ๋ฏธํฐ์ ์ ์๋ฅผ ๋ณ๊ฒฝํ์ฌ ํ๋ จ๋๋ ๋ชจ๋ธ์ ์ค์ฐจ๊ฐ ์์๋ณด๋ค ๋์์ง๊ฒ ํ ์ ์์ผ๋ฉฐ, ์ด๋ฌํ ๋ฌธ์ ๋ค์ ๋ฐ๊ฒฌํ๊ธฐ ์ด๋ ค์ธ ์ ์์ต๋๋ค. ์๋์ ์ฃผ์ฅ๋ค์ ๋ช ํํ์ง๋ง, ๊ธฐ๋ณธ solver๋ฅผ ์ถฉ์คํ ๊ตฌํํ๊ธฐ ์ํด ๋ช ์์ ์ผ๋ก ๊ณ ๋ คํ๋ ๊ฒ์ด ์ค์
Weight decay
- ์์ค ํจ์์ L2 ์ ๊ทํ ํญ์ ๊ธฐ์ธ๊ธฐ์ ๊ฒฐ๊ณผ
- per-sample loss
l(x,w)=λ2โฃโฃwโฃโฃ2+ε(x,w)l(x, w) = \frac{\lambda }{2}||w||^2 + \varepsilon (x, w)
- λ2โฃโฃwโฃโฃ2\frac{\lambda}{2} ||w||^2 : ๊ฐ์ค์น์ ๋ํ ์ํ ๋ ๋ฆฝ์ ์ธ L2 ์ ๊ทํ
- ε(x,w)\varepsilon (x, w) : ํฌ๋ก์ค ์ํธ๋กํผ์ ์ํ ์ข ์์ ์ธ ํญ(sample-dependent term)
- SGD update
wt+1=wt−ηλwt−η1n∑x∈Bโฝε(x,wt)w_{t+1} = w_t - \eta \lambda w_t - \eta\frac{1}{n}\sum_{x \in B} \bigtriangledown \varepsilon (x, w_t)
- ์ค์ ๋ก๋ ์ผ๋ฐ์ ์ผ๋ก ์ญ์ ํ๋ฅผ ํตํด ์ํ ์ข ์์ ์ธ ํญ∑โฝε(x,wt)\sum \bigtriangledown \varepsilon (x, w_t) ๋ง ๊ณ์ฐ๋จ
- ๊ฐ์ค์น ๊ฐ์ ํญλwt\lambda w_t ๋ ๋ณ๋๋ก ๊ณ์ฐ๋์ดε(x,wt)\varepsilon (x, w_t)์ ๊ธฐ์ฌํ ๊ทธ๋ ๋์ธํธ์ ์ถ๊ฐ๋จ
- ๊ฐ์ค์น ๊ฐ์ ํญ์ด ์๋ ๊ฒฝ์ฐ, ํ์ต๋ฅ ์ ์กฐ์ ํ๋ ๋ค์ํ ๋ฐฉ๋ฒ์ด ์์
- ε(x,wt)\varepsilon (x, w_t)ํญ์ ์กฐ์ ํ๋ ๋ฐฉ๋ฒ์ ํฌํจ
- ๊ทธ๋ฌ๋ ์ด๋ ํญ์ ์ฑ๋ฆฝํ์ง ์์.
Momentum correction
- Momentum SGD
- ๋ฐ๋๋ผ SGD์ ๋ํ ํํ ์ฑํ๋๋ modification์ค ํ๋
- ๋ชจ๋ฉํ
SGD์ ์ฐธ์กฐ ๊ตฌํ
ut+1=mut+1nโฝl(x,wt)−−−−(9)u_{t+1} = mu_t + \frac{1}{n} \bigtriangledown l (x, w_t)----(9)wt+1=wt−ηut+1−−−−(9)w_{t+1} = w_t - \eta u_{t+1}----(9)
- m : momentum decay factor
- u : update tensor
- ์ธ๊ธฐ ์๋ ๋ณํ
- ํ์ต๋ฅ η๋ฅผ ์
๋ฐ์ดํธ ํ
์์ ํก์ํ๋ ๊ฒ
vt+1=mut+η1n∑x∈Bโฝl(x,wt)−−−−(10)v_{t+1} = mu_t + \eta \frac{1}{n}\sum_{x \in B} \bigtriangledown l (x, w_t)----(10)wt+1=wt−vt+1−−−−(10)w_{t+1} = w_t - v_{t+1}----(10)
- ๊ณ ์ ๋η\eta์ ๋ํด 2๊ฐ์ง๊ฐ ๋์ผํจ
- u๊ฐ ๊ทธ๋ ์ด๋์ธํธ์๋ง ์์กดํ๊ณ η\eta์๋ ๋ ๋ฆฝ์ ์ธ ๋ฐ๋ฉด, v๋η\eta์ ์ฝํ ์๋ค๋ ์ ์ ์ฃผ๋ชฉํด์ผ ํจ.
- η\eta๊ฐ ๋ณ๊ฒฝ๋ฌ์ ๋ (9)์ ์ฐธ์กฐ๋ณํ๊ณผ ๋์ผ์ฑ์ ์ ์ง๋๊ธฐ ์ํด v์ ์
๋ฐ์ดํธ๋ ๋ค์๊ณผ ๊ฐ์์ ธ์ผ ํจ
vt+1=mηt+1ηt+ηt+11n∑โฝl(x,wt)v_{t+1} = m \frac{\eta_{t+1}}{\eta_t} + \eta_{t+1}\frac{1}{n}\sum \bigtriangledown l (x, w_t)
- ηt+1ηt\frac{\eta_{t+1}}{\eta_{t}} : ๋ชจ๋ฉํ ๋ณด์ (momentum correction)
- ηt+1>>ηt\eta_{t+1} >> \eta_t ์ธ ๊ฒฝ์ฐ ํ๋ จ์ ์์ ํ์ํค๋ ๋ฐ ์ค์ํจ.
- ํ์ต๋ฅ η๋ฅผ ์
๋ฐ์ดํธ ํ
์์ ํก์ํ๋ ๊ฒ
Gradient aggregation
- k ๊ฐ์ worker๋ง๋ค ๊ฐ๊ฐ ํฌ๊ธฐ๊ฐ n์ธ ๋ฏธ๋๋ฐฐ์น๋ฅผ ๊ฐ์ง๊ณ ์๋ ๊ฒฝ์ฐ, (4)์ ๋ฐ๋ผ ๊ทธ๋ ์ด๋์ธํธ ์ง๊ณ๋ kn๊ฐ์ ์ ์ฒด ์์ ์ธํธ์ ๋ํด ๋ค์ ์์ด ์ํ๋์ด์ผ ํจ.
1kn∑j∑x∈Bjl(x,wt)\frac{1}{kn}\sum_j\sum_{x \in B_j}l(x, w_t)
- ์์ค ๋ ์ด์ด
- ์ผ๋ฐ์ ์ผ๋ก ์์ฒด์ ๋ก์ปฌ ์ ๋ ฅ์ ๋ํ ํ๊ท ์์ค์ ๊ณ์ฐํ๋ ๋ฐฉ์์ผ๋ก ๊ตฌํ
- ๊ฐ ์์ ์์ ์์ค∑l(x,wt)/n\sum l(x, w_t)/n์ ๊ณ์ฐํ๋ ๊ฒ๊ณผ ๋์ผํจ
- ์๊ธฐ ๋ด์ฉ์ ๊ณ ๋ คํ๋ฉด ์ฌ๋ฐ๋ฅธ ์ง๊ณ๋ ๋๋ฝ๋ 1/k์์๋ฅผ ๋ณต์ํ๊ธฐ ์ํด k๊ฐ์ ๊ทธ๋ ๋์ธํธ๋ฅผ ํ๊ท ํํด์ผํจ.
- ๊ทธ๋ฌ๋, allreduce [11]์ ๊ฐ์ ํ์ค ํต์ ์์๋ ํ๊ท ์ด ์๋ ํฉ์ ์ํ
- ๋ฐ๋ผ์ 1/k์ค์ผ์ผ๋ง์ ์์ค์ ํก์ํ๋ ๊ฒ์ด ๋ ํจ์จ์ ์ด๋ฉฐ, ์ด ๊ฒฝ์ฐ์๋ ์์ค์ ๋ํ ์ ๋ ฅ์ ๊ทธ๋ ์ด๋์ธํธ๋ง ์ค์ผ์ผ๋งํ๋ฉด ๋๋ฏ๋ก ์ ์ฒด ๊ทธ๋ ์ด๋์ธํธ ๋ฒกํฐ๋ฅผ ์ค์ผ์ผ๋งํ ํ์๊ฐ ์์.Remark 3: Normalize the per-worker loss by total minibatch size kn, not per-worker size n
- ๋ํ ‘k๋ฅผ ์ทจ์’ํ๊ธฐ์ํดη^=η\hat{\eta} = \eta ์ค์ ํ๊ณ ์์ค์ 1/n์ผ๋ก ์ ๊ทํํ๋ ๊ฒ์ ์๋ชป๋ ์ ์์
- ์ด๋ ์๋ชป๋ ๊ฐ์ค์น ๊ฐ์ ๋ก ์ด์ด์ง (← remark 1 ์ฐธ์กฐ)
Data shuffling
- SGD๋ ์ผ๋ฐ์ ์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ๋ณต์์ถ์ถ๋ก ์ํ๋งํ๋ ํ๋ก์ธ์ค๋ก ๋ถ์
- ์ค์ ๋ก ์ผ๋ฐ์ ์ธ SGD ๊ตฌํ์์๋ ๊ฐ SGD ์ํฌํฌ๋ง๋ค ํ๋ จ ์ธํธ๋ฅผ ๋ฌด์์๋ก ์์ด์ฃผ๋ ๊ฒ์ด ํํ๋ฉฐ, ์ด๋ ๋ ๋์ ๊ฒฐ๊ณผ๋ฅผ ์ป์ ์์์.
- ๋ฌด์์ ์
ํ๋ง์ ์ฌ์ฉํ๋ ๊ธฐ์ค์ (K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016)๊ณผ ๊ณต์ ํ ๋น๊ต๋ฅผ ์ ๊ณตํ๊ธฐ ์ํด, k worker์ ์ํด ์ํ๋ ํ ์ํฌํฌ์ ์ํ์ ํ๋ จ ์ธํธ์ ๋จ์ผ ์ผ๊ด๋ ๋ฌด์์ ์๊ธฐ์์ ๊ฐ์ ธ์ค๋๋ก ํจ.
- ์ด๋ฅผ ๋ฌ์ฑํ๊ธฐ ์ํด ๊ฐ ์ํฌํฌ๋ง๋ค k ๋ถ๋ถ์ผ๋ก ๋๋์ด์ง ๋ฌด์์ ์๊ธฐ๋ฅผ ์ฌ์ฉํ๋ฉฐ, ๊ฐ ๋ถ๋ถ์ k worker ์ค ํ๋์ ์ํด ์ฒ๋ฆฌ
- ์ฌ๋ฌ worker์์ ๋ฌด์์ ์ ํ๋ง์ ์ฌ๋ฐ๋ฅด๊ฒ ๊ตฌํํ์ง ์์ผ๋ฉด ๋๋ ทํ ๋ค๋ฅธ ๋์์ ์ ๋ฐ ๊ฐ๋ฅํจ⇒๊ฒฐ๊ณผ์ ๊ฒฐ๋ก ์ ์ค์ผ
Remark 4: Use a single random shuffling of the training data (per epoch) that is divided amongst all k workers
- ํ๋์ Big Basin ์๋ฒ์์์ 8๊ฐ์ GPU๋ฅผ ๋์ด์ ๊ท๋ชจ๋ฅผ ํ์ฅํ๋ ค๋ฉด , ๊ทธ๋ ์ด๋์ธํธ ์ง๊ณ๋ ๋คํธ์ํฌ ์์ ์ฌ๋ฌ ์๋ฒ์ ๊ฑธ์ณ ์ด๋ค์ ธ์ผ ํจ. ๊ฑฐ์ ์๋ฒฝํ ์ ํ ํ์ฅ์ ํ์ฉํ๊ธฐ ์ํด์๋ ์ง๊ณ๊ฐ ์ญ์ ํ์ ๋ณ๋ ฌ๋ก ์ํ๋์ด์ผ ํฉ๋๋ค. ์ด๋ ์ธต ๊ฐ์ ๊ทธ๋ ์ด๋์ธํธ ๊ฐ์ ๋ฐ์ดํฐ ์์กด์ฑ์ด ์๊ธฐ ๋๋ฌธ์ ๊ฐ๋ฅํฉ๋๋ค. ๋ฐ๋ผ์ ํ ์ธต์ ๊ทธ๋ ์ด๋์ธํธ๊ฐ ๊ณ์ฐ๋๋ฉด ์ฆ์ ํด๋น ์ธต์ ๋ํ ๊ทธ๋ ์ด๋์ธํธ๊ฐ ์์ ์๋ค ๊ฐ์ ์ง๊ณ๋๊ณ , ๋์์ ๋ค์ ์ธต์ ๋ํ ๊ทธ๋ ์ด๋์ธํธ ๊ณ์ฐ์ด ๊ณ์๋ฉ๋๋ค. ๋ค์์์๋ ์ด์ ๋ํ ์์ธํ ๋ด์ฉ์ ์ค๋ช ํ๊ฒ ์ต๋๋ค.
4.1 Gradient Aggregation
- ๊ฐ ๊ทธ๋ ์ด๋์ธํธ์ ๋ํด ์ง๊ณ๋ MPI ์งํฉ ์ฐ์ฐ(MPI Allreduce)[11]์ ์ ์ฌํ allreduce ์์ ์ ์ฌ์ฉํ์ฌ ์ํ๋ฉ๋๋ค. Allreduce๊ฐ ์์๋๊ธฐ ์ ์ ๊ฐ GPU๋ ๋ก์ปฌ๋ก ๊ณ์ฐ๋ ๊ทธ๋ ์ด๋์ธํธ๋ฅผ ๊ฐ๊ณ ์๊ณ , allreduce๊ฐ ์๋ฃ๋๋ฉด ๊ฐ GPU๋ ๋ชจ๋ k๊ฐ์ ๊ทธ๋ ์ด๋์ธํธ์ ํฉ์ ๊ฐ์ต๋๋ค. ๋งค๊ฐ๋ณ์์ ์๊ฐ ์ฆ๊ฐํ๊ณ GPU์ ๊ณ์ฐ ์ฑ๋ฅ์ด ํฅ์๋จ์ ๋ฐ๋ผ ์ง๊ณ ๋น์ฉ์ backprop ๋จ๊ณ์์ ์จ๊ธฐ๊ธฐ๊ฐ ๋ ์ด๋ ค์์ง๋๋ค. ์ด๋ฌํ ํจ๊ณผ๋ฅผ ๊ทน๋ณตํ๊ธฐ ์ํ ํ๋ จ ๊ธฐ์ ์ ์ด ์์ ์ ๋ฒ์๋ฅผ ๋ฒ์ด๋ฉ๋๋ค (์: ์์ํ๋ ๊ทธ๋ ์ด๋์ธํธ(quantized gradient) [18], ๋ธ๋ก-๋ชจ๋ฉํ SGD [6]). ๊ทธ๋ฌ๋ ์ด ์์ ์ ๊ท๋ชจ์์๋ ์ต์ ํ๋ allreduce ๊ตฌํ์ ์ฌ์ฉํ์ฌ ๊ฑฐ์ ์ ํ์ ์ธ SGD ์ค์ผ์ผ๋ง์ ๋ฌ์ฑํ ์ ์์ด, ์ง๋จ ํต์ ์ด ๋ณ๋ชฉ์ด ๋์ง ์์์ต๋๋ค.
- allreduce ๊ตฌํ 3๋จ๊ณ : ์๋ฒ ๋ด๋ถ ๋ฐ ์๋ฒ ๊ฐ ํต์ ์ ์ํด์.
๐ก1. ์๋ฒ ๋ด์ 8๊ฐ GPU์์ ๊ฐ๊ฐ์ ๋ฒํผ๊ฐ ๊ฐ ์๋ฒ์ ๋ํด ํ๋์ ๋จ์ผ ๋ฒํผ๋ก ํฉ์ฐ
- ๊ฒฐ๊ณผ ๋ฒํผ๋ ๋ชจ๋ ์๋ฒ ๊ฐ์ ๊ณต์ ๋์ด ํฉ์ฐ
- ๊ฒฐ๊ณผ๊ฐ ๊ฐ GPU๋ก ๋ธ๋ก๋์บ์คํธ๋จ
- ์๋ฒ ๊ฐ allreduce๋ฅผ ์ํด ๋์ญํญ ์ ํ ์๋๋ฆฌ์ค์ ๋ํ ๋ ๊ฐ์ง ์ต๊ณ ์ ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํ
- ์ฌ๊ท์ ์ธ ๋ฐ๊ฐ ๋ฐ ๋ฐฐ๊ฐ ์๊ณ ๋ฆฌ์ฆ
- 2 log2(p) ํต์ ๋จ๊ณ๋ก ์ด๋ฃจ์ด์ง
- reduce-scatter ์งํฉ์ ์ด์ด allgather๋ก ์ด๋ฃจ์ด์ ธ ์์
- process
- ์๋ฒ๋ ์์ผ๋ก ํต์ ํ๋ฉฐ (๋ญํฌ 0์ 1๊ณผ, 2๋ 3๊ณผ ๋ฑ๋ฑ), ์ ๋ ฅ ๋ฒํผ์ ๋ค๋ฅธ ๋ฐ์ชฝ์ ๋ํด ๋ณด๋ด๊ณ ๋ฐ์ (๋ญํฌ 0์ ๋ฒํผ์ ๋ ๋ฒ์งธ ์ ๋ฐ์ 1์๊ฒ ๋ณด๋ด๊ณ 1๋ก๋ถํฐ ๋ฒํผ์ ์ฒซ ๋ฒ์งธ ์ ๋ฐ์ ๋ฐ์)
- ๋ค์ ๋จ๊ณ๋ก ์งํํ๊ธฐ ์ ์ ์์ ๋ ๋ฐ์ดํฐ์ ๋ํ ์ถ์๊ฐ ์ํ๋๊ณ , ๋ค์ ๋จ๊ณ์์๋ ๋ชฉ์ ์ง ๋ญํฌ๊น์ง์ ๊ฑฐ๋ฆฌ๊ฐ ๋ ๋ฐฐ๋ก ๋์ด๋๋ฉด์ ๋ณด๋ด๊ณ ๋ฐ์ ๋ฐ์ดํฐ๊ฐ ์ ๋ฐ์ผ๋ก ์ค์ด๋ค์
- reduce-scatter ๋จ๊ณ๊ฐ ์๋ฃ๋๋ฉด ๊ฐ ์๋ฒ์๋ ์ต์ข ์ถ์๋ ๋ฒกํฐ์ ์ผ๋ถ๊ฐ ์์.
- allgather ๋จ๊ณ
- reduce-scatter์์์ ํต์ ํจํด์ ์ญ์ผ๋ก ์ถ์ ํ์ฌ ์ต์ข ์ถ์๋ ๋ฒกํฐ์ ์ผ๋ถ๋ฅผ ๊ฐ๋จํ ์ฐ๊ฒฐ
- ๊ฐ ์๋ฒ์์ reduce-scatter์์ ๋ณด๋ด๊ณ ์๋ ๋ฒํผ์ ์ผ๋ถ๊ฐ allgather์์ ์์ ๋๊ณ , ๋ฐ๋ ๋ถ๋ถ์ ์ด์ ๋ณด๋ด์ง
- ์๋ฒ์ ์๊ฐ 2์ ๊ฑฐ๋ญ์ ๊ณฑ์ด ์๋ ๊ฒฝ์ฐ์ ๋์ํ๊ธฐ ์ํด ์ด์ง ๋ธ๋ก ์๊ณ ๋ฆฌ์ฆ [30]์ ์ฌ์ฉ
- ์ด์ง ๋ธ๋ก ์๊ณ ๋ฆฌ์ฆ : ์๋ฒ๊ฐ 2์ ๊ฑฐ๋ญ์ ๊ณฑ ๋ธ๋ก์ผ๋ก ๋ถํ ๋๊ณ ๋ ๊ฐ์ ์ถ๊ฐ ํต์ ๋จ๊ณ๊ฐ ์ฌ์ฉ๋๋ ๋ฐ๊ฐ/๋ฐฐ๊ฐ ์๊ณ ๋ฆฌ์ฆ์ ์ผ๋ฐํ๋ ๋ฒ์
- ์๋ ๋ธ๋ก ๋ด๋ถ reduce-scatter ํ์ ๋ธ๋ก ๋ด๋ถ allgather ์ ์ ๊ฐ๊ฐ ํ ๋ฒ ์ฌ์ฉ๋จ
- ๊ฑฐ๋ญ์ ๊ณฑ์ด ์๋ ๊ฒฝ์ฐ ์ผ๋ถ ๋ถํ ๋ถ๊ท ํ์ด ๊ฑฐ๋ญ์ ๊ณฑ๊ณผ ๋น๊ตํ์ฌ ๋ฐ์ํ์ง๋ง, ํ ๋ ผ๋ฌธ์์๋ ์ฑ๋ฅ์ ํ๋ฅผ ๊ด์ฐฐํ์ง ๋ชปํจ.
- process
- ๋ฒํท ์๊ณ ๋ฆฌ์ฆ (๋ง ์๊ณ ๋ฆฌ์ฆ์ด๋ผ๊ณ ๋ ํจ)
- 2(p−1) ๋จ๊ณ๋ก ์ด๋ฃจ์ด์ง.
- ๋ฐ๊ฐ/๋ฐฐ๊ฐ ์๊ณ ๋ฆฌ์ฆ์ด ๋๊ฐ latency-limited ์๋๋ฆฌ์ค์์ ๋ ๋น ๋ฅด๊ฒ ๋์(์ฆ, ์์ ๋ฒํผ ํฌ๊ธฐ ๋ฐ/๋๋ ํฐ ์๋ฒ ์์ ๊ฒฝ์ฐ ,์ฝ 3).
- ์ฌ๊ท์ ์ธ ๋ฐ๊ฐ ๋ฐ ๋ฐฐ๊ฐ ์๊ณ ๋ฆฌ์ฆ
4.2 software
- ํต์ ์์ง์ ์ํ allreduce ์๊ณ ๋ฆฌ์ฆ์ Gloo ๊นํ๋ธ์ ์์.
- ๋ณ๋ ฌ๋ก ์ฌ๋ฌ allreduce ์ธ์คํด์ค๋ฅผ ์คํํ๊ธฐ ์ํด ์ถ๊ฐ ๋๊ธฐํ๊ฐ ํ์ํ์ง ์๋ ์ฌ๋ฌ ํต์ context๋ฅผ ์ง์
- ๋ก์ปฌ ๋ฆฌ๋์ ๋ฐ ๋ธ๋ก๋์บ์คํธ (๋จ๊ณ (1) ๋ฐ (3)๋ก ์ค๋ช ๋จ)์ ๊ฐ๋ฅํ ๊ฒฝ์ฐ ์๋ฒ ๊ฐ allreduce์ ํ์ดํ๋ผ์ธํ๋จ
- Caffe2
- ํ๋ จ ๋ฐ๋ณต์ ๋ํ๋ด๋ ์ปดํจํธ ๊ทธ๋ํ์ ๋ฉํฐ์ค๋ ๋ ์คํ์ ์ง์
- ์๋ธ๊ทธ๋ํ ๊ฐ์ ๋ฐ์ดํฐ ์์กด์ฑ์ด ์๋ ๊ฒฝ์ฐ ์ฌ๋ฌ ์ค๋ ๋๊ฐ ๊ทธ ์๋ธ๊ทธ๋ํ๋ฅผ ๋ณ๋ ฌ๋ก ์คํ๊ฐ๋ฅ
- ์ด๋ฅผ backprop์ ์ ์ฉํ๋ฉด ๋ก์ปฌ ๊ทธ๋ ์ด๋์ธํธ๊ฐ ์์ฐจ์ ์ผ๋ก ๊ณ์ฐ๋ ์ ์๊ณ , allreduce๋ ๊ฐ์ค์น ์
๋ฐ์ดํธ์ ๊ด๋ จ์ด ์์
- ์ด๋ backprop ์ค์ ์คํ ๊ฐ๋ฅํ ์๋ธ๊ทธ๋ํ ์งํฉ์ด ์คํ ๊ฐ๋ฅํ ์๋ธ๊ทธ๋ํ๋ฅผ ์คํํ๋ ์๋๋ณด๋ค ๋ ๋นจ๋ฆฌ ์ฆ๊ฐํ ์ ์์์ ์๋ฏธ
- allreduce๋ฅผ ํฌํจํ๋ ์๋ธ๊ทธ๋ํ์ ๊ฒฝ์ฐ ๋ชจ๋ ์๋ฒ๊ฐ ์คํ ๊ฐ๋ฅํ ์๋ธ๊ทธ๋ํ ์งํฉ์์ ๋์ผํ ์๋ธ๊ทธ๋ํ๋ฅผ ์คํํ๋๋ก ์ ํํด์ผ ํจ
- ๊ทธ๋ ์ง ์์ผ๋ฉด ์๋ฒ๊ฐ ์๋ก ๊ต์ฐจํ์ง ์๋ ์๋ธ๊ทธ๋ํ ์งํฉ์ ์คํํ๋ ค๊ณ ํ ๋ ๋ถ์ฐ ๋ฐ๋๋ฝ์ด ๋ฐ์ํ ์ํ์ด ์์.
- allreduce๊ฐ ์งํฉ ์ฐ์ฐ์ด๊ธฐ ๋๋ฌธ์ ์๋ฒ๋ ํ์์์ ๋ ๊ฑฐ์
- ์ฌ๋ฐ๋ฅธ ์คํ์ ๋ณด์ฅํ๊ธฐ ์ํด ์ด๋ฌํ ์๋ธ๊ทธ๋ํ์ ๋ํ ๋ถ๋ถ์ ์ธ ์์๋ฅผ ๋ถ์ฌํด์ผํจ.
- ์ํ์ ์ธ ์ ์ด ์ ๋ ฅ์ ์ฌ์ฉํ์ฌ ๊ตฌํ๋จ
- n๋ฒ์งธ allreduce์ ์๋ฃ๊ฐ (n + c)๋ฒ์งธ allreduce์ ์คํ์ ๋ธ๋ก ํด์ ํ๊ฒ ๋จ
- ์ฌ๊ธฐ์ c๋ ์ต๋ ๋์ allreduce ์คํ ํ์
- ์ด ์ซ์๋ ์ ์ฒด ์ปดํจํธ ๊ทธ๋ํ๋ฅผ ์คํํ๋ ๋ฐ ์ฌ์ฉ๋๋ ์ค๋ ๋ ์๋ณด๋ค ๋ฎ๊ฒ ์ ํํด์ผ ํจ.