Background

 

Distilling the Knowledge in a Neural Network (arxiv.org)

"A very simple way to improve the performance of almost any machine learning algorithm is to train many different models on the same data and then to average their predictions. Unfortunately, making predictions using a whole ensemble of models is cumbersome..."

 

While reading Machine Learning Design Patterns, I came across a distillation technique in Chapter 4, so I looked up the original paper.

 

Summary

Until now, large-scale machine learning systems have used the same model for training and deployment, which makes inference expensive in terms of resources. The paper tries to overcome this constraint by transferring the knowledge of a large model into a single small model. This is what the term “distillation” refers to.

A conventional hard label such as [1,0,0] puts all of the probability on one class. In reality, though, a dog might look somewhat like a cat, so a label such as [0.6,0.4,0] can also carry meaningful knowledge. That is the core idea.

In short, the paper proposes using the values produced by the large model's final layer as soft targets for training the small model.

Distillation

๊ธฐ์กด์˜ ์‹ ๊ฒฝ๋ง์€ ํด๋ž˜์Šค ๋ถ„๋ฅ˜ ํƒœ์Šคํฌ๋ฅผ ์ˆ˜ํ–‰ํ•  ๋•Œ output layer ์— softmax ๋ฅผ ์ทจํ•ด logit ์„ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.

Normally T is set to 1. As T increases, the class probabilities become softer: the distribution spreads more evenly across the classes.
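To make the effect of T concrete, here is a minimal Python sketch of the temperature softmax q_i = exp(z_i/T) / Σ_j exp(z_j/T). The logit values and the (dog, cat, car) class names are made up for illustration.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Return q_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtracting the max avoids overflow and does not change the result
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for the classes (dog, cat, car)
logits = [5.0, 3.0, -2.0]
for T in (1, 2, 5, 20):
    print(T, [round(p, 3) for p in softmax_with_temperature(logits, T)])
```

Dividing by a positive T preserves the order of the logits, so the ranking of the classes never changes. Only the share of probability given to the non-top classes grows as T rises.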

In the distillation scheme proposed by the paper, the large model is called the teacher model, and the small model built for a specific purpose (such as deployment) is called the student (the distilled model).

The procedure works as follows. Set T to a high value and feed the transfer set to the teacher to obtain soft labels (1). Run the student on the same inputs at the same temperature T to obtain soft predictions (2), and at T = 1 to obtain hard predictions (3). The loss is a weighted sum of two cross entropies: one between (1) and (2), and one between (3) and the ground-truth labels. According to the paper, adding the ground-truth term improves on using the soft targets alone. The best results generally came from giving that hard-label term a considerably lower weight.
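Below is a minimal sketch of that combined objective for a single example, written in plain Python. T = 4.0 and alpha = 0.9 are arbitrary illustrative values, not settings from the paper. The sketch follows the paper in giving the hard-label term a much smaller weight and in multiplying the soft term by T².

```python
import math

def softmax(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)  # numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, pred, eps=1e-12):
    # C = -sum_i target_i * log(pred_i); eps guards against log(0)
    return -sum(t * math.log(p + eps) for t, p in zip(target, pred))

def distillation_loss(student_logits, teacher_logits, label, T=4.0, alpha=0.9):
    soft_labels = softmax(teacher_logits, T)   # (1) teacher at high temperature
    soft_preds = softmax(student_logits, T)    # (2) student at the same temperature
    hard_preds = softmax(student_logits, 1.0)  # (3) student at T = 1
    one_hot = [1.0 if k == label else 0.0 for k in range(len(student_logits))]
    soft_term = cross_entropy(soft_labels, soft_preds)
    hard_term = cross_entropy(one_hot, hard_preds)
    # The soft term is multiplied by T^2 so its gradient magnitude stays roughly
    # comparable across temperatures; alpha (illustrative) weights the two terms.
    return alpha * T * T * soft_term + (1.0 - alpha) * hard_term

student = [2.0, 1.0, -1.0]   # hypothetical student logits
teacher = [3.0, 1.5, -2.0]   # hypothetical teacher logits
print(distillation_loss(student, teacher, label=0))
```

Many implementations use a KL divergence instead of the soft cross entropy. Since the teacher's entropy does not depend on the student, both give the same gradient with respect to the student's logits.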

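The gradient of this cross entropy C with respect to each student logit z_i is given below. Here v_i are the logits of the large (teacher) model, p_i is the teacher's soft-target probability, and q_i is the student's softmax output at the same temperature.

$$\frac{\partial C}{\partial z_i} = \frac{1}{T}\left(q_i - p_i\right) = \frac{1}{T}\left(\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right)$$

The following steps show that, at high temperature, minimizing the cross entropy between q_i and p_i turns into matching the logits themselves.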

In other words, distillation computes the student's gradient from the difference between the teacher's and the student's outputs.
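As a sanity check on the formula ∂C/∂z_i = (q_i − p_i)/T, the sketch below compares it with central finite differences of C = −Σ_i p_i log q_i. The student logits z, the teacher logits v, and T = 3 are hypothetical values.

```python
import math

def softmax(logits, T=1.0):
    scaled = [x / T for x in logits]
    m = max(scaled)  # numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_cross_entropy(z, v, T):
    # C = -sum_i p_i log q_i, with p = softmax(v / T) (teacher), q = softmax(z / T) (student)
    p = softmax(v, T)
    q = softmax(z, T)
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q))

def analytic_grad(z, v, T):
    # dC/dz_i = (q_i - p_i) / T
    p = softmax(v, T)
    q = softmax(z, T)
    return [(qi - pi) / T for qi, pi in zip(q, p)]

def numeric_grad(z, v, T, h=1e-6):
    # Central finite differences, perturbing one student logit at a time
    grads = []
    for i in range(len(z)):
        up = list(z)
        up[i] += h
        down = list(z)
        down[i] -= h
        grads.append((soft_cross_entropy(up, v, T) - soft_cross_entropy(down, v, T)) / (2 * h))
    return grads

z = [1.0, -0.5, 2.0]   # hypothetical student logits
v = [2.0, 0.0, -1.0]   # hypothetical teacher logits
T = 3.0
print(analytic_grad(z, v, T))
print(numeric_grad(z, v, T))
```

The gradient components also sum to zero, because q and p each sum to 1.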

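If T is large compared with the magnitude of the logits, the first-order Taylor approximation e^x ≈ 1 + x gives the following, where N is the number of classes:

$$\frac{\partial C}{\partial z_i} \approx \frac{1}{T}\left(\frac{1 + z_i/T}{N + \sum_j z_j/T} - \frac{1 + v_i/T}{N + \sum_j v_j/T}\right)$$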

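If we further assume that each model's logits have zero mean for every example, i.e. $\sum_j z_j = \sum_j v_j = 0$, the sums in the denominators vanish and the expression simplifies to:

$$\frac{\partial C}{\partial z_i} \approx \frac{1}{N T^2}\left(z_i - v_i\right)$$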

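So the gradient of the cross entropy with respect to the student's logits is scaled by 1/(NT²). The conclusion is that when T is large enough and the logits are zero-mean, distillation is equivalent to minimizing ½(z_i − v_i)², that is, to matching the teacher's logits directly. The same 1/T² factor is why the paper multiplies the soft-target gradients by T². This keeps their contribution roughly unchanged when the temperature is varied.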
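The high-temperature approximation ∂C/∂z_i ≈ (z_i − v_i)/(NT²) can also be checked numerically. The sketch below zero-centers two hypothetical logit vectors. It then measures the largest gap between the exact gradient (q_i − p_i)/T and the approximation, relative to the size of the approximation.

```python
import math

def softmax(logits, T=1.0):
    scaled = [x / T for x in logits]
    m = max(scaled)  # numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def zero_mean(xs):
    mu = sum(xs) / len(xs)
    return [x - mu for x in xs]

def exact_grad(z, v, T):
    # dC/dz_i = (q_i - p_i) / T
    p = softmax(v, T)
    q = softmax(z, T)
    return [(qi - pi) / T for qi, pi in zip(q, p)]

def high_T_grad(z, v, T):
    # dC/dz_i ~ (z_i - v_i) / (N T^2), valid for large T and zero-mean logits
    N = len(z)
    return [(zi - vi) / (N * T * T) for zi, vi in zip(z, v)]

def relative_error(z, v, T):
    exact = exact_grad(z, v, T)
    approx = high_T_grad(z, v, T)
    return max(abs(e - a) for e, a in zip(exact, approx)) / max(abs(a) for a in approx)

z = zero_mean([1.0, -0.5, 2.0, 0.3])   # hypothetical student logits
v = zero_mean([2.0, 0.0, -1.0, 0.5])   # hypothetical teacher logits
for T in (1, 5, 20, 100):
    print(T, relative_error(z, v, T))
```

As T grows, the relative error shrinks roughly in proportion to 1/T. This is consistent with the first-order Taylor expansion used in the derivation.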

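Conversely, at lower temperatures distillation pays much less attention to matching logits that are far more negative than the average. Those logits are barely constrained by the loss used to train the teacher, so they may be very noisy and ignoring them can help, although they may also carry useful knowledge. The paper reports that when the student is too small to capture all of the teacher's knowledge, intermediate temperatures work best.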
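The sketch below shows what this means. At several temperatures, it computes the soft-target gradient on a class whose teacher logit is very negative, scaled by T² as the paper does for soft-target gradients. The logit values are made up. The student already matches the teacher on the other two classes.

```python
import math

def softmax(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)  # numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_soft_grad(student_logits, teacher_logits, T):
    # T^2 * dC/dz_i = T * (q_i - p_i): the soft-target gradient after T^2 rescaling
    p = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    return [T * (qi - pi) for qi, pi in zip(q, p)]

teacher = [5.0, 0.0, -10.0]  # hypothetical: class 2 has a very negative teacher logit
student = [5.0, 0.0, -3.0]   # matches classes 0 and 1, but misses class 2 by 7
for T in (1, 2, 5, 20):
    grad_on_class_2 = scaled_soft_grad(student, teacher, T)[2]
    print(T, grad_on_class_2)
```

At T = 1 the mismatch on class 2 barely moves the student, because both p_2 and q_2 are close to zero. At high temperature the same mismatch produces a much larger gradient, and intermediate temperatures fall between these two regimes.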

Contributions

The paper's main contributions can be summarized in three points.

1. Formalizing knowledge transfer through soft targets

Conventional classifiers used only the correct label (a one-hot vector) as the training signal. This paper instead takes the view that the model's entire output probability distribution is knowledge.

In particular, the softmax distribution output by the teacher contains information such as:

  • similarity between classes
  • how confused the model is
  • implicit structure of the data distribution

Training the student on this distribution makes it imitate the teacher's decision structure itself, not just reproduce the correct answer.

This became the starting point for most later work on knowledge distillation.


2. Temperature as a mechanism for controlling the probability distribution

The paper introduces a temperature T into the softmax to control the shape of the probability distribution.

  • T = 1: the ordinary softmax
  • T > 1: a flatter distribution

Raising the temperature lowers the model's confidence and exposes more of the relative relationships between classes. As a result, the student learns the structure of the probability distribution rather than just the correct answer.

Another important contribution is the approximate analysis showing that, in the high-temperature limit, KD reduces to minimizing the difference between the teacher's and the student's logits.


3. A method for compressing ensembles

The paper shows that a high-performing teacher obtained by ensembling several models can be compressed into a single student model.

This has the following practical implications:

  • the large model is used only during training
  • a lightweight model is used at deployment
  • faster inference and lower resource usage

In other words, the technique decouples the model used for training from the model used for deployment, which makes it highly valuable in industry.
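The sketch below shows how an ensemble teacher can supply soft targets: it averages the temperature-T distributions of several members. Taking the arithmetic mean is an assumption made here for simplicity. The member logits and T = 3 are hypothetical.

```python
import math

def softmax(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)  # numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def ensemble_soft_targets(member_logits, T):
    """Arithmetic mean of the members' temperature-T distributions (assumed choice)."""
    dists = [softmax(logits, T) for logits in member_logits]
    n_classes = len(dists[0])
    return [sum(d[k] for d in dists) / len(dists) for k in range(n_classes)]

# Hypothetical logits from three ensemble members for one input
members = [[4.0, 1.0, -1.0], [3.0, 2.5, -2.0], [5.0, 0.5, 0.0]]
targets = ensemble_soft_targets(members, T=3.0)
print([round(t, 3) for t in targets])
```

The student is then trained against these averaged targets in the same way as with a single teacher. Only the student has to run at inference time.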

Experimental Results

MNIST

The MNIST experiment used the following setup:

  • Teacher: a single large network with two hidden layers of 1,200 rectified linear units, regularized with dropout and weight constraints (the paper views dropout as training a large ensemble of models that share weights)
  • Student: a smaller network with two hidden layers of 800 rectified linear units

As a result, the student showed the following:

  1. Lower error than the same network trained only on hard labels (74 vs. 146 test errors when it also matched the teacher's soft targets at T = 20)
  2. Accuracy close to the teacher's (67 test errors)
  3. Less overfitting, since matching the soft targets was the student's only regularization

In particular, soft targets tended to improve generalization even when training data was limited.

This suggests that soft targets act as a kind of regularizer.


Speech Recognition

The speech recognition experiment used an ensemble of 10 deep acoustic models as the teacher.

The teacher was large and complex, which made it inefficient to deploy directly.

The student trained with distillation showed:

  • far fewer parameters than the whole ensemble
  • faster inference
  • accuracy close to the teacher's

Training on soft targets also converged more stably than training on hard labels.

This shows that KD is useful beyond small image-classification benchmarks, including in a large-scale speech recognition system. Note, however, that the acoustic model there still classifies individual frames rather than whole sequences.
