[Paper Review] Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour

EunGyeongKim 2024. 2. 5. 13:17

💡
On the ImageNet dataset, the paper shows that large minibatches make optimization difficult (but once this problem is solved, the trained model generalizes well). To address it, the authors adopt a hyper-parameter-free linear scaling rule that adjusts the learning rate as a function of the minibatch size, and develop a new warmup scheme that overcomes optimization problems early in training.

This paper’s goal

  • Demonstrate that distributed synchronous SGD (stochastic gradient descent) is suitable for large-scale training, and deliver practical guidance

 

What is needed to handle large minibatch sizes

  • Adjust the learning rate using a hyper-parameter-free linear scaling rule
    • This guideline was established in earlier works
    • Its empirical limits are not well understood, and as informal knowledge it is not widely known in the research community
    • To apply this rule successfully, a new warmup strategy is proposed
      • warmup strategy = use a lower learning rate at the start of training to overcome early optimization problems
  • The loss $L(w)$:
    $L(w) = \frac{1}{|X|}\sum_{x\in X} l(x, w)$
    • $w$ : the weights of the network
    • $X$ : the labeled training set
    • $l(x, w)$ : the loss computed from a sample $x \in X$ and its label $y$
    • $l$ : the sum of a classification loss (cross-entropy) and a regularization loss on $w$
  • Minibatch stochastic gradient descent operates on minibatches. It is usually referred to simply as SGD, and performs the following update:
    $w_{t+1} = w_t - \eta\frac{1}{n}\sum_{x \in B} \nabla l(x, w_t)$
    • $B$ : a minibatch sampled from $X$
    • $n = |B|$ : the minibatch size
    • $\eta$ : the learning rate
    • $t$ : the iteration index
  • ์ด ๋…ผ๋ฌธ์—์„œ๋Š” ์‹ค์ œ๋กœ momentum SGD๋ฅผ ์‚ฌ์šฉํ•จ.

 

2.1 Learning Rates for Large Minibatches

  • Goal
    • Maintain training and generalization accuracy while using large minibatches in place of small ones ⇒ especially important for distributed learning, since it allows scaling to many workers with simple data parallelism, without reducing the per-worker workload or sacrificing model accuracy!
    • (worker and GPU are used interchangeably)

 

  • A learning rate scaling rule turns out to be surprisingly effective across a wide range of minibatch sizes.
    Linear scaling rule: when the minibatch size is multiplied by k, multiply the learning rate by k.
    • Keep all other hyper-parameters (weight decay, etc.) unchanged
    • The linear scaling rule helps match accuracy between small and large minibatches, and also matches their training curves ⇒ enabling quick comparison and debugging of experiments (see the sketch below)
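A minimal sketch of the rule, assuming the paper's reference setting of a base learning rate of 0.1 for minibatch size 256 (the function name is illustrative):

```python
BASE_LR = 0.1      # reference learning rate for the base minibatch size
BASE_BATCH = 256   # e.g., 8 GPUs x 32 samples per GPU

def scaled_lr(batch_size):
    k = batch_size / BASE_BATCH
    return BASE_LR * k  # multiply the learning rate by k, change nothing else

print(scaled_lr(8192))  # 3.2 for the 8k minibatch setting
```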
  • Interpretation
    • Why the linear scaling rule is effective → consider a network at iteration $t$ with weights $w_t$, and a sequence of $k$ minibatches $B_j$ for $0 \leq j < k$
      • Compare the effect of executing $k$ SGD iterations (small minibatches $B_j$, learning rate $\eta$) against a single iteration (one large minibatch $\cup_j B_j$ of size $kn$, learning rate $\hat{\eta}$)
        • $k$ iterations of SGD with learning rate $\eta$ and minibatch size $n$:
          $w_{t+k} = w_t - \eta \frac{1}{n} \sum_{j<k}\sum_{x \in B_j} \nabla l(x, w_{t+j})$ ---- (3)
        • a single step with learning rate $\hat{\eta}$ and the large minibatch $\cup_j B_j$ of size $kn$:
          $\hat{w}_{t+1} = w_t - \hat{\eta}\frac{1}{kn}\sum_{j<k}\sum_{x \in B_j}\nabla l(x, w_t)$ ---- (4)
      • As expected, the two updates differ, and in general there is no reason to expect $\hat{w}_{t+1} = w_{t+k}$.
      • However, when we can assume $\nabla l(x, w_t) \approx \nabla l(x, w_{t+j})$ for $j < k$, setting $\hat{\eta} = k\eta$ yields $\hat{w}_{t+1} \approx w_{t+k}$, and small-minibatch SGD and large-minibatch SGD behave similarly (a numeric check follows the list below).
      • Note that even when this strong assumption holds, the two updates are similar only if $\hat{\eta} = k\eta$.
    • This interpretation provides intuition for one case in which the linear scaling rule can hold. In experiments with $\hat{\eta} = k\eta$ (and warmup), small- and large-minibatch SGD not only reach the same final model accuracy but also closely match training curves. The experimental results suggest that the above approximation can be valid on large-scale real data.
    • However, there are two cases in which the condition $\nabla l(x, w_t) \approx \nabla l(x, w_{t+j})$ does not hold:
      • initial training, when the network is changing rapidly ⇒ addressed with the warmup of §2.2
      • the minibatch size cannot be scaled indefinitely:
        • results are stable over a wide range of sizes, but beyond a certain point accuracy degrades rapidly
        • in the ImageNet experiments this point is larger than 8k
  • Discussion
    • The linear scaling rule above was adopted by Krizhevsky. However, Krizhevsky reported a 1% increase in error when the minibatch size grew from 128 to 1024, whereas this work shows how to maintain accuracy across a much broader range of minibatch sizes.
    • Chen et al. presented a comparison of numerous distributed SGD variants. They used the linear scaling rule, but did not establish a small-minibatch baseline.
    • Li showed distributed ImageNet training with minibatches up to 5120 without a loss of accuracy after convergence, but did not give a hyper-parameter rule for adjusting the learning rate as a function of minibatch size, which is central to this work.
    • In recent work, Bottou et al. [4] (§4.2) review the theoretical tradeoffs of minibatching and show that, with the linear scaling rule, solvers follow the same training curve as a function of the number of examples seen; they also state that the learning rate must not exceed a maximum rate independent of the minibatch size (which justifies warmup). This paper experimentally validates these theories with unprecedented minibatch sizes.

2.2 Warmup

  • ๋…ผ์˜ํ•œ ๋Œ€๋กœ, ๋Œ€๊ทœ๋ชจ ๋ฏธ๋‹ˆ๋ฐฐ์น˜(์˜ˆ: 8k)์˜ ๊ฒฝ์šฐ, ์‹ ๊ฒฝ๋ง์ด ๋น ๋ฅด๊ฒŒ ๋ณ€ํ•  ๋•Œ ์„ ํ˜• ์Šค์ผ€์ผ๋ง ๊ทœ์น™์ด ๋ถ•๊ดด๋จ. ์ด๋Š” ํ›ˆ๋ จ ์ดˆ๊ธฐ์— ํ”ํžˆ ๋ฐœ์ƒํ•จ→ ์ด ๋ฌธ์ œ๋ฅผ ์ ์ ˆํžˆ ์„ค๊ณ„๋œ ์›œ์—… [16]์— ์˜ํ•ด ์™„ํ™”์‹œํ‚ฌ ์ˆ˜ ์žˆ์Œ.
    • ํ›ˆ๋ จ ์‹œ์ž‘ ์‹œ์— ๋œ ๊ณต๊ฒฉ์ ์ธ ํ•™์Šต๋ฅ ์„ ์‚ฌ์šฉํ•˜๋Š” ์ „๋žต
  • Constant warmup
    • The warmup strategy proposed in [16] uses a low constant learning rate for the first few epochs of training.
    • Constant warmup is useful for object detection and segmentation methods that fine-tune pre-trained layers together with newly initialized layers.
    • In the ImageNet experiments with a large minibatch of size kn, training runs at a learning rate of $\eta$ for the first 5 epochs, and then returns to the target learning rate $\hat{\eta} = k\eta$. However, for large $k$, this constant warmup alone was found to be insufficient to solve the optimization problem, and the transition out of the low-learning-rate warmup phase can cause the training error to spike ⇒ this motivated the following gradual warmup.
  • Gradual warmup
    • An alternative warmup that gradually ramps the learning rate from a small value to a large one.
    • It avoids a sudden jump in the learning rate and allows healthy convergence at the start of training.
    • In practice, with a large minibatch of size kn, training starts from a learning rate of $\eta$ and increases it by a constant amount at each iteration so that it reaches $\hat{\eta} = k\eta$ after 5 epochs (the results are robust to the exact warmup duration); a schedule sketch follows below. After warmup, training goes back to the original learning rate schedule.
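A minimal sketch of such a schedule (the 5-epoch ramp is the paper's setting; function and argument names are illustrative):

```python
def warmup_lr(iteration, eta, k, iters_per_epoch, warmup_epochs=5):
    """Linearly ramp the learning rate from eta to k*eta over the warmup."""
    warmup_iters = warmup_epochs * iters_per_epoch
    if iteration < warmup_iters:
        # increase by a constant amount each iteration, reaching k*eta
        return eta + (k * eta - eta) * iteration / warmup_iters
    return k * eta  # afterwards, the regular schedule (step decays, etc.) applies
```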

 

2.3 Batch Normalization with Large Minibatches

  • Batch Normalization (BN): computes statistics along the minibatch dimension
    • this breaks the independence of each sample's loss
    • changes in the minibatch size change the underlying definition of the loss function being optimized
  • A shortcut for preserving the loss function when changing the minibatch size: keep the per-worker minibatch size fixed
    • a practical consideration for avoiding communication overhead
    • also essential for preserving the loss function
    • When BN is performed and activations are computed across samples, the usual assumption no longer holds:
      • assumption: the per-sample loss $l(x, w)$ is independent of all other samples
      • write the loss of a single minibatch $B$ of size $n$ as $L(B, w) = \frac{1}{n}\sum_{x \in B} l_B(x, w)$
      • with BN, the training set can be viewed as containing all distinct subsets of size $n$ drawn from the original training set $X$ ⇒ $X^n$
      • training loss $L(w)$:
        $L(w) = \frac{1}{|X^n|}\sum_{B \in X^n} L(B, w)$
      • If $B$ is viewed as a 'single sample' of $X^n$, the loss of each single sample $B$ is computed independently ⇒ note that the minibatch size $n$, over which BN statistics are computed, is a key component of the loss
      • if the per-worker minibatch sample size $n$ is changed, the underlying loss function $L$ being optimized changes
      • BN mean/variance statistics computed with different $n$ exhibit different levels of random variation
      • ๋ถ„์‚ฐ๋œ(๊ทธ๋ฆฌ๊ณ  ๋ฉ€ํ‹ฐ-GPU) ํ›ˆ๋ จ
        • ๋งŒ์ผ worker๋‹น ์ƒ˜ํ”Œ ํฌ๊ธฐ n์ด ๊ณ ์ •๋˜์–ด ์žˆ๊ณ  ์ด ๋ฏธ๋‹ˆ๋ฐฐ์น˜ ํฌ๊ธฐ๊ฐ€ kn์ผ ๋•Œ
          • ๊ฐ ์ƒ˜ํ”ŒBjB_j๊ฐ€XnX^n์—์„œ ๋…๋ฆฝ์ ์œผ๋กœ ์„ ํƒ๋œ k๊ฐœ์˜ ์ƒ˜ํ”Œ ๋ฏธ๋‹ˆ๋ฐฐ์น˜๋กœ ๋ณผ ์ˆ˜ ์žˆ์Œ.
          ⇒ ๊ธฐ๋ณธ ์†์‹ค ํ•จ์ˆ˜๋Š” ๋ณ€๊ฒฝ๋˜์ง€ ์•Š์œผ๋ฉฐ ์—ฌ์ „ํžˆXnX^n์—์„œ ์ •์˜๋จ
          • ์ด ๊ด€์ ์—์„œ BN ์„ค์ •์—์„œ, k๊ฐœ์˜ ๋ฏธ๋‹ˆ๋ฐฐ์น˜BjB_j๋ฅผ ๋ณธ ํ›„, (3)๊ณผ (4)๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๊ณ„์‚ฐ๋จ
            wt+k=wt−η∑j<kโ–ฝL(Bj,wt+j)w_{t+k} = w_t - \eta \sum_{j<k} \bigtriangledown L (B_j, w_{t+j})
            wt+k^=wt−η^1k∑j<kโ–ฝL(Bj,wt+j)\hat{w_{t+k}} = w_t - \hat{\eta} \frac{1}{k} \sum_{j<k} \bigtriangledown L (B_j, w_{t+j})
          • ๋ณธ ์—ฐ๊ตฌ์—์„œ๋Š”η^=kη\hat{\eta} = k\eta๋กœ ์„ค์ •ํ•˜๊ณ  worker ์ˆ˜ k๋ฅผ ๋ณ€๊ฒฝํ•  ๋•Œ worker๋‹น ์ƒ˜ํ”Œ ํฌ๊ธฐ n์„ ์ผ์ •ํ•˜๊ฒŒ ์œ ์ง€ํ•จ. ๊ทธ๋ฆฌ๊ณ  ๋‹ค์–‘ํ•œ ๋ฐ์ดํ„ฐ์…‹๊ณผ ๋„คํŠธ์›Œํฌ์—์„œ ์ž˜ ๋™์ž‘ํ•œ n = 32๋ฅผ ์‚ฌ์šฉ
          • ๋งŒ์•ฝ n์ด ์กฐ์ •๋œ๋‹ค๋ฉด, ์ด๋Š” BN์˜ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋กœ ๊ฐ„์ฃผ๋˜์–ด์•ผ ํ•˜๋ฉฐ ๋ถ„์‚ฐ ํ›ˆ๋ จ์ด ์•„๋‹˜.
          • ๋˜ํ•œ BN ํ†ต๊ณ„๋Š” ํ†ต์‹ ์„ ์ค„์ด๊ธฐ ์œ„ํ•ด์„œ๋งŒ์ด ์•„๋‹ˆ๋ผ ์ตœ์ ํ™”๋˜๋Š” ๊ธฐ๋ณธ ์†์‹ค ํ•จ์ˆ˜๋ฅผ ๋™์ผํ•˜๊ฒŒ ์œ ์ง€ํ•˜๊ธฐ ์œ„ํ•ด์„œ๋„ ๋ชจ๋“  worker๋ฅผ ๋Œ€์ƒ์œผ๋กœ ๊ณ„์‚ฐํ•˜๋ฉด ์•ˆ๋จ
  • Errors in a distributed implementation can quietly change the definitions of hyperparameters, causing the trained model's error to be higher than expected, and such problems can be hard to detect. The claims below are self-evident, but it is important to consider them explicitly in order to implement the underlying solver faithfully.
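To make the per-worker statistics concrete, a minimal numpy sketch (shapes and names are assumptions for illustration): BN normalizes each worker's n local samples using statistics from those n samples only, never across all kn.

```python
import numpy as np

def bn_forward_per_worker(x_local, gamma, beta, eps=1e-5):
    # x_local: (n, d) activations on one worker; stats use only these n samples
    mu = x_local.mean(axis=0)
    var = x_local.var(axis=0)
    x_hat = (x_local - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta

k, n, d = 4, 32, 8                       # n = 32 per worker, as in the paper
batch = np.random.randn(k * n, d)        # the full kn-sample minibatch
gamma, beta = np.ones(d), np.zeros(d)
outs = [bn_forward_per_worker(chunk, gamma, beta)
        for chunk in np.split(batch, k)]  # one BN per worker-sized chunk
```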

Weight decay

  • ์†์‹ค ํ•จ์ˆ˜์˜ L2 ์ •๊ทœํ™” ํ•ญ์˜ ๊ธฐ์šธ๊ธฐ์˜ ๊ฒฐ๊ณผ
  • per-sample loss
    l(x,w)=λ2โˆฃโˆฃwโˆฃโˆฃ2+ε(x,w)l(x, w) = \frac{\lambda }{2}||w||^2 + \varepsilon (x, w)
    • λ2โˆฃโˆฃwโˆฃโˆฃ2\frac{\lambda}{2} ||w||^2 : ๊ฐ€์ค‘์น˜์— ๋Œ€ํ•œ ์ƒ˜ํ”Œ ๋…๋ฆฝ์ ์ธ L2 ์ •๊ทœํ™”
    • ε(x,w)\varepsilon (x, w) : ํฌ๋กœ์Šค ์—”ํŠธ๋กœํ”ผ์˜ ์ƒ˜ํ”Œ ์ข…์†์ ์ธ ํ•ญ(sample-dependent term)
  • SGD update
wt+1=wt−ηλwt−η1n∑x∈Bโ–ฝε(x,wt)w_{t+1} = w_t - \eta \lambda w_t - \eta\frac{1}{n}\sum_{x \in B} \bigtriangledown \varepsilon (x, w_t)
  • ์‹ค์ œ๋กœ๋Š” ์ผ๋ฐ˜์ ์œผ๋กœ ์—ญ์ „ํŒŒ๋ฅผ ํ†ตํ•ด ์ƒ˜ํ”Œ ์ข…์†์ ์ธ ํ•ญ∑โ–ฝε(x,wt)\sum \bigtriangledown \varepsilon (x, w_t) ๋งŒ ๊ณ„์‚ฐ๋จ
  • ๊ฐ€์ค‘์น˜ ๊ฐ์‡  ํ•ญλwt\lambda w_t ๋Š” ๋ณ„๋„๋กœ ๊ณ„์‚ฐ๋˜์–ดε(x,wt)\varepsilon (x, w_t)์— ๊ธฐ์—ฌํ•œ ๊ทธ๋ ˆ๋””์–ธํŠธ์— ์ถ”๊ฐ€๋จ
  • ๊ฐ€์ค‘์น˜ ๊ฐ์‡  ํ•ญ์ด ์—†๋Š” ๊ฒฝ์šฐ, ํ•™์Šต๋ฅ ์„ ์กฐ์ ˆํ•˜๋Š” ๋‹ค์–‘ํ•œ ๋ฐฉ๋ฒ•์ด ์žˆ์Œ
    • ε(x,wt)\varepsilon (x, w_t)ํ•ญ์„ ์กฐ์ ˆํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ํฌํ•จ
    • ๊ทธ๋Ÿฌ๋‚˜ ์ด๋Š” ํ•ญ์ƒ ์„ฑ๋ฆฝํ•˜์ง€ ์•Š์Œ.
    Remark 1: Scaling the cross-entropy loss is not equivalent to scaling the learning rate
  •  
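A numeric illustration of Remark 1 with toy values (all constants assumed): scaling the cross-entropy term by s differs from scaling the learning rate by s, because only the latter also scales the weight decay gradient λw.

```python
w, lam, eta, s = 1.0, 1e-4, 0.1, 2.0
g_ce = 0.5  # stand-in for the cross-entropy gradient (1/n) * sum grad eps(x, w)

w_scaled_lr   = w - (s * eta) * (lam * w + g_ce)  # lr scaled: decay scaled too
w_scaled_loss = w - eta * (lam * w + s * g_ce)    # loss scaled: decay unscaled
print(w_scaled_lr, w_scaled_loss)  # differ by (s - 1) * eta * lam * w
```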

Momentum correction

  • Momentum SGD
    • one of the most commonly adopted modifications of vanilla SGD
    • reference implementation of momentum SGD:
      $u_{t+1} = m u_t + \frac{1}{n}\sum_{x \in B} \nabla l(x, w_t)$ ---- (9)
      $w_{t+1} = w_t - \eta u_{t+1}$ ---- (9)
      • $m$ : momentum decay factor
      • $u$ : update tensor
    • a popular variant
      • absorbs the learning rate $\eta$ into the update tensor:
        $v_{t+1} = m v_t + \eta\frac{1}{n}\sum_{x \in B} \nabla l(x, w_t)$ ---- (10)
        $w_{t+1} = w_t - v_{t+1}$ ---- (10)
      • for a fixed $\eta$, the two are equivalent
      • note that whereas $u$ depends only on the gradients and is independent of $\eta$, $v$ is entangled with $\eta$
      • when $\eta$ changes, to maintain equivalence with the reference variant (9), the update for $v$ must become:
        $v_{t+1} = m\frac{\eta_{t+1}}{\eta_t} v_t + \eta_{t+1}\frac{1}{n}\sum \nabla l(x, w_t)$
        • $\frac{\eta_{t+1}}{\eta_t}$ : the momentum correction
        • important for stabilizing training when $\eta_{t+1} \gg \eta_t$
        ⇒ otherwise the history term $v_t$ is too small, causing instability (when $\eta_{t+1} < \eta_t$ the momentum correction matters less); a sketch follows below.
        Remark 2: Apply momentum correction after changing learning rate if using (10)
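A minimal sketch of variant (10) with the correction applied (names are illustrative):

```python
def momentum_step(w, v, g, m, eta_prev, eta_new):
    """One step of variant (10) with momentum correction.
    g is the averaged minibatch gradient (1/n) * sum grad l(x, w)."""
    v = m * (eta_new / eta_prev) * v + eta_new * g  # correction factor eta_new/eta_prev
    w = w - v
    return w, v
```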

Gradient aggregation

  • k ๊ฐœ์˜ worker๋งˆ๋‹ค ๊ฐ๊ฐ ํฌ๊ธฐ๊ฐ€ n์ธ ๋ฏธ๋‹ˆ๋ฐฐ์น˜๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ๋Š” ๊ฒฝ์šฐ, (4)์— ๋”ฐ๋ผ ๊ทธ๋ ˆ์ด๋””์–ธํŠธ ์ง‘๊ณ„๋Š” kn๊ฐœ์˜ ์ „์ฒด ์˜ˆ์ œ ์„ธํŠธ์— ๋Œ€ํ•ด ๋‹ค์Œ ์‹์ด ์ˆ˜ํ–‰๋˜์–ด์•ผ ํ•จ.
1kn∑j∑x∈Bjl(x,wt)\frac{1}{kn}\sum_j\sum_{x \in B_j}l(x, w_t)
  • ์†์‹ค ๋ ˆ์ด์–ด
    • ์ผ๋ฐ˜์ ์œผ๋กœ ์ž์ฒด์˜ ๋กœ์ปฌ ์ž…๋ ฅ์— ๋Œ€ํ•œ ํ‰๊ท  ์†์‹ค์„ ๊ณ„์‚ฐํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ ๊ตฌํ˜„
    • ๊ฐ ์ž‘์—…์ž์˜ ์†์‹ค∑l(x,wt)/n\sum l(x, w_t)/n์„ ๊ณ„์‚ฐํ•˜๋Š” ๊ฒƒ๊ณผ ๋™์ผํ•จ
  • ์ƒ๊ธฐ ๋‚ด์šฉ์„ ๊ณ ๋ คํ•˜๋ฉด ์˜ฌ๋ฐ”๋ฅธ ์ง‘๊ณ„๋Š” ๋ˆ„๋ฝ๋œ 1/k์š”์†Œ๋ฅผ ๋ณต์›ํ•˜๊ธฐ ์œ„ํ•ด k๊ฐœ์˜ ๊ทธ๋ ˆ๋””์–ธํŠธ๋ฅผ ํ‰๊ท ํ™”ํ•ด์•ผํ•จ.
    • ๊ทธ๋Ÿฌ๋‚˜, allreduce [11]์™€ ๊ฐ™์€ ํ‘œ์ค€ ํ†ต์‹  ์›์‹œ๋Š” ํ‰๊ท ์ด ์•„๋‹Œ ํ•ฉ์„ ์ˆ˜ํ–‰
  • ๋”ฐ๋ผ์„œ 1/k์Šค์ผ€์ผ๋ง์„ ์†์‹ค์— ํก์ˆ˜ํ•˜๋Š” ๊ฒƒ์ด ๋” ํšจ์œจ์ ์ด๋ฉฐ, ์ด ๊ฒฝ์šฐ์—๋Š” ์†์‹ค์— ๋Œ€ํ•œ ์ž…๋ ฅ์˜ ๊ทธ๋ ˆ์ด๋””์–ธํŠธ๋งŒ ์Šค์ผ€์ผ๋งํ•˜๋ฉด ๋˜๋ฏ€๋กœ ์ „์ฒด ๊ทธ๋ ˆ์ด๋””์–ธํŠธ ๋ฒกํ„ฐ๋ฅผ ์Šค์ผ€์ผ๋งํ•  ํ•„์š”๊ฐ€ ์—†์Œ.Remark 3: Normalize the per-worker loss by total minibatch size kn, not per-worker size n
  • ๋˜ํ•œ ‘k๋ฅผ ์ทจ์†Œ’ํ•˜๊ธฐ์œ„ํ•ดη^=η\hat{\eta} = \eta ์„ค์ •ํ•˜๊ณ  ์†์‹ค์„ 1/n์œผ๋กœ ์ •๊ทœํ™”ํ•˜๋Š” ๊ฒƒ์€ ์ž˜๋ชป๋  ์ˆ˜ ์žˆ์Œ
    • ์ด๋Š” ์ž˜๋ชป๋œ ๊ฐ€์ค‘์น˜ ๊ฐ์‡ ๋กœ ์ด์–ด์ง (← remark 1 ์ฐธ์กฐ)

Data shuffling

  • SGD is usually analyzed as a process that samples data with replacement.
    • In practice, common SGD implementations reshuffle the training set randomly at each epoch, which can give better results.
    • To provide a fair comparison against the baseline, which uses random shuffling (K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016), the samples processed in one epoch by the $k$ workers are drawn from a single consistent random shuffle of the training set.
      • To achieve this, each epoch uses one random shuffle divided into $k$ parts, each part processed by one of the $k$ workers.
      • Failing to implement random shuffling correctly across multiple workers can cause noticeably different behavior ⇒ contaminating results and conclusions.
Remark 4: Use a single random shuffling of the training data (per epoch) that is divided amongst all k workers
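A minimal sketch of Remark 4 (function and seed handling are illustrative): every worker derives the same per-epoch shuffle from a shared seed and takes its own contiguous part.

```python
import numpy as np

def epoch_partitions(num_samples, k, epoch, seed=0):
    rng = np.random.default_rng(seed + epoch)  # identical shuffle on every worker
    order = rng.permutation(num_samples)       # one single random shuffling
    return np.array_split(order, k)            # part j is processed by worker j

parts = epoch_partitions(num_samples=12, k=3, epoch=0)  # worker j uses parts[j]
```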

 

  • To scale beyond the 8 GPUs of a single Big Basin server, gradient aggregation must span multiple servers over the network. To allow near-perfect linear scaling, the aggregation must be performed in parallel with backprop. This is possible because there are no data dependencies between the gradients of different layers. Therefore, as soon as the gradient for a layer is computed, it is aggregated across workers while gradient computation for the next layer continues. The details are described below.

4.1 Gradient Aggregation

  • ๊ฐ ๊ทธ๋ ˆ์ด๋””์–ธํŠธ์— ๋Œ€ํ•ด ์ง‘๊ณ„๋Š” MPI ์ง‘ํ•ฉ ์—ฐ์‚ฐ(MPI Allreduce)[11]์™€ ์œ ์‚ฌํ•œ allreduce ์ž‘์—…์„ ์‚ฌ์šฉํ•˜์—ฌ ์ˆ˜ํ–‰๋ฉ๋‹ˆ๋‹ค. Allreduce๊ฐ€ ์‹œ์ž‘๋˜๊ธฐ ์ „์— ๊ฐ GPU๋Š” ๋กœ์ปฌ๋กœ ๊ณ„์‚ฐ๋œ ๊ทธ๋ ˆ์ด๋””์–ธํŠธ๋ฅผ ๊ฐ–๊ณ  ์žˆ๊ณ , allreduce๊ฐ€ ์™„๋ฃŒ๋˜๋ฉด ๊ฐ GPU๋Š” ๋ชจ๋“  k๊ฐœ์˜ ๊ทธ๋ ˆ์ด๋””์–ธํŠธ์˜ ํ•ฉ์„ ๊ฐ–์Šต๋‹ˆ๋‹ค. ๋งค๊ฐœ๋ณ€์ˆ˜์˜ ์ˆ˜๊ฐ€ ์ฆ๊ฐ€ํ•˜๊ณ  GPU์˜ ๊ณ„์‚ฐ ์„ฑ๋Šฅ์ด ํ–ฅ์ƒ๋จ์— ๋”ฐ๋ผ ์ง‘๊ณ„ ๋น„์šฉ์„ backprop ๋‹จ๊ณ„์—์„œ ์ˆจ๊ธฐ๊ธฐ๊ฐ€ ๋” ์–ด๋ ค์›Œ์ง‘๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ํšจ๊ณผ๋ฅผ ๊ทน๋ณตํ•˜๊ธฐ ์œ„ํ•œ ํ›ˆ๋ จ ๊ธฐ์ˆ ์€ ์ด ์ž‘์—…์˜ ๋ฒ”์œ„๋ฅผ ๋ฒ—์–ด๋‚ฉ๋‹ˆ๋‹ค (์˜ˆ: ์–‘์žํ™”๋œ ๊ทธ๋ ˆ์ด๋””์–ธํŠธ(quantized gradient) [18], ๋ธ”๋ก-๋ชจ๋ฉ˜ํ…€ SGD [6]). ๊ทธ๋Ÿฌ๋‚˜ ์ด ์ž‘์—…์˜ ๊ทœ๋ชจ์—์„œ๋Š” ์ตœ์ ํ™”๋œ allreduce ๊ตฌํ˜„์„ ์‚ฌ์šฉํ•˜์—ฌ ๊ฑฐ์˜ ์„ ํ˜•์ ์ธ SGD ์Šค์ผ€์ผ๋ง์„ ๋‹ฌ์„ฑํ•  ์ˆ˜ ์žˆ์–ด, ์ง‘๋‹จ ํ†ต์‹ ์ด ๋ณ‘๋ชฉ์ด ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.
  • The allreduce implementation has three phases, covering intra-server and inter-server communication:
    💡
    1. The buffers from the 8 GPUs within a server are summed into a single buffer per server
    2. The resulting buffers are shared and summed across all servers
    3. The result is broadcast to each GPU
    • For the inter-server allreduce, the two best algorithms for the bandwidth-limited scenario were implemented:
      • the recursive halving and doubling algorithm
        • consists of $2\log_2(p)$ communication steps
        • consists of a reduce-scatter collective followed by an allgather
          • process
            • Servers communicate in pairs (rank 0 with 1, rank 2 with 3, and so on), sending and receiving different halves of their input buffers (rank 0 sends the second half of its buffer to rank 1 and receives the first half of rank 1's buffer)
            • A reduction over the received data is performed before proceeding to the next step; in each subsequent step the distance to the destination rank doubles while the amount of data sent and received halves
            • After the reduce-scatter phase completes, each server holds a portion of the final reduced vector
            • allgather phase
              • retraces the communication pattern of the reduce-scatter in reverse, simply concatenating portions of the final reduced vector
            • At each server, the portion of the buffer that was sent during the reduce-scatter is received during the allgather, and the portion that was received is now sent
            • To handle server counts that are not powers of two, the binary blocks algorithm [30] is used
              • binary blocks algorithm: a generalization of the halving/doubling algorithm in which servers are partitioned into power-of-two blocks and two additional communication steps are used
              • once after the intra-block reduce-scatter and once before the intra-block allgather
              • In the non-power-of-two case some load imbalance occurs compared with the power-of-two case, but this paper did not observe a performance penalty.
      • the bucket algorithm (also known as the ring algorithm)
        • consists of $2(p-1)$ communication steps
      → In both, each server sends and receives $2(p-1)b/p$ bytes of data, where $b$ is the buffer size in bytes and $p$ is the number of servers; a single-process simulation of the ring variant follows below.
      • The halving/doubling algorithm runs considerably faster than the bucket algorithm in latency-limited scenarios (i.e., for small buffer sizes and/or large server counts).

4.2 Software

  • ํ†ต์‹  ์ˆ˜์ง‘์„ ์œ„ํ•œ allreduce ์•Œ๊ณ ๋ฆฌ์ฆ˜์€ Gloo ๊นƒํ—ˆ๋ธŒ์— ์žˆ์Œ.
    • ๋ณ‘๋ ฌ๋กœ ์—ฌ๋Ÿฌ allreduce ์ธ์Šคํ„ด์Šค๋ฅผ ์‹คํ–‰ํ•˜๊ธฐ ์œ„ํ•ด ์ถ”๊ฐ€ ๋™๊ธฐํ™”๊ฐ€ ํ•„์š”ํ•˜์ง€ ์•Š๋Š” ์—ฌ๋Ÿฌ ํ†ต์‹  context๋ฅผ ์ง€์›
    • ๋กœ์ปฌ ๋ฆฌ๋•์…˜ ๋ฐ ๋ธŒ๋กœ๋“œ์บ์ŠคํŠธ (๋‹จ๊ณ„ (1) ๋ฐ (3)๋กœ ์„ค๋ช…๋จ)์€ ๊ฐ€๋Šฅํ•œ ๊ฒฝ์šฐ ์„œ๋ฒ„ ๊ฐ„ allreduce์™€ ํŒŒ์ดํ”„๋ผ์ธํ™”๋จ
  • Caffe2
    • supports multi-threaded execution of the compute graph that represents one training iteration
    • when there are no data dependencies between subgraphs, multiple threads can execute those subgraphs in parallel
    • Applied to backprop, this means local gradients can be computed in sequence without waiting on the allreduce or on weight updates
      • so during backprop, the set of runnable subgraphs can grow faster than the rate at which runnable subgraphs are executed
    • For subgraphs that contain an allreduce, all servers must choose to execute the same subgraph from their set of runnable subgraphs
      • otherwise there is a risk of distributed deadlock, with servers attempting to execute non-intersecting sets of subgraphs
      • since allreduce is a collective operation, the servers would time out
    • To guarantee correct execution, a partial order is imposed on these subgraphs
      • implemented using cyclic control inputs
      • the completion of the $n$-th allreduce unblocks execution of the $(n + c)$-th allreduce, as sketched below
      • where $c$ is the maximum number of concurrent allreduce executions
        • this number should be chosen smaller than the number of threads used to execute the full compute graph

 

 
