DIAYN (Diversity is All You Need)

Diversity is All You Need: Learning Skills without a Reward Function

(https://arxiv.org/abs/1802.06070)

์œ„ ๋…ผ๋ฌธ์—์„œ ๋ฐ”๋ผ๋ณด๊ณ ์ž ํ•˜๋Š” ์ ์€ ๋˜‘๋˜‘ํ•œ ์‚ฌ๋žŒ๋“ค์€ ์ง€์‹œ์ž์˜ ์ง€์‹œ๊ฐ€ ์—†์ด๋„ ์Šค์Šค๋กœ ๋‹ค์–‘ํ•œ ์‹œ๋„๋“ค์„ ํ†ตํ•˜์—ฌ ์Šค์Šค๋กœ ์œ ์šฉํ•œ ๊ธฐ์ˆ ๋“ค์„ ํ•™์Šตํ•œ๋‹ค๋Š” ์ผ์ƒ์˜ ์ผ๋“ค์„ ๊ฐ•ํ™”ํ•™์Šต์— ์ ์šฉํ•ด ๋ณด์ž๋Š” ์˜๋„๋ฅผ ๊ฐ€์ง€๊ณ  ์ž‘์„ฑ๋˜์—ˆ๋‹ค.
์ด๋ฅผ ๊ฐ•ํ™”ํ•™์Šต์œผ๋กœ ํ’€์–ด๋ณด์ž๋ฉด ์–ด๋–ป๊ฒŒ ํ•˜๋ฉด ๋ณด์ƒ์ด ์—†์ด๋„ ์œ ์šฉํ•œ ๊ธฐ์ˆ ๋“ค์„ ์Šค์Šค๋กœ ๋ฐฐ์›Œ์„œ ํ•™์Šตํ•  ์ˆ˜ ์žˆ์„๊นŒ ์ด๋ฅผ ๊ตฌํ˜„ํ•˜๊ธฐ ์œ„ํ•ด์„œ ํ•„์š”ํ•œ ๊ฑด ๋ฌด์—‡์ผ๊นŒ? ์ด๋‹ค.
reward function์ด ์—†์ด useful skills๋ฅผ ํ•™์Šตํ•œ๋‹ค. ์ด๋ฅผ ๊ธฐ์ˆ ์ ์œผ๋กœ ํ’€์–ด๋ณด์ž๋ฉด maximum entropy policy๋ฅผ ์ด์šฉํ•˜์—ฌ information theoretic objective๋ฅผ ์ตœ๋Œ€ํ™”ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ํ†ตํ•ด skills๋ฅผ ํ•™์Šตํ•˜๋Š” ๋ฐฉ๋ฒ•๋ก ์ด๋ผ๊ณ  ๋ณด๋ฉด ๋œ๋‹ค.
์—ฌ๊ธฐ์„œ ์ค‘์š”ํ•œ ํ‚ค์›Œ๋“œ๋Š” maximum entropy policy์ด๋‹ค.
๋”ฅ๋Ÿฌ๋‹์—์„œ ๊ฐ€์žฅ ์ค‘์š”ํ•œ ๊ฒƒ์€ Data์™€ Optimization์ธ๋ฐ ์ตœ์ ํ™”์— maximum entropy policy๋ฅผ ์ ์šฉํ•œ๋‹ค๋Š” ๊ฑด noise๋ฅผ ํ†ตํ•œ smooth ํšจ๊ณผ์˜ ์ ์šฉ์„ ํ†ตํ•˜์—ฌ ๋‹จ๊ธฐ์  ๋ณด์ƒ๋ณด๋‹ค๋Š” ์žฅ๊ธฐ์ ์ธ ๋ณด์ƒ์— ๋Œ€ํ•œ ๊ธฐ๋Œ€๊ฐ’์„ ๋†’์ด๊ธฐ ์œ„ํ•จ์ด๋‹ค.
Entropy๋ฅผ ์ •๋ณด ์ด๋ก ์—์„œ๋Š” ์ด๋ฒคํŠธ์— ํฌํ•จ๋œ ์˜ˆ์ƒ ์ •๋ณด ๋น„ํŠธ ์ˆ˜๋กœ ์ •์˜ํ•  ์ˆ˜ ์žˆ๊ณ  ์œ„ ๋…ผ๋ฌธ์—์„œ์˜ Maximum Entropy๋ผ๋Š” ๊ฒƒ์€ ์—์ด์ „ํŠธ๊ฐ€ ๊ฐ€์žฅ ๋†’์€ ๋ณด์ƒ ํ•ฉ๊ณ„์™€ ์žฅ๊ธฐ Entropy ํ•ฉ๊ณ„๋ฅผ ๋ฐ›์„ ์ˆ˜ ์žˆ๋Š” ๊ฐ€์žฅ ์˜ฌ๋ฐ”๋ฅธ ์ž‘์—…์„ ์„ ํƒํ•˜๊ธฐ ์œ„ํ•ด Policy ์ตœ์ ํ™”๋ฅผ ์ˆ˜ํ–‰ํ•œ๋‹ค๋Š”๋ฐ ์žˆ๋‹ค๊ณ  ๋ณด๋ฉด ๋œ๋‹ค.
์ด ์ตœ์ ํ™” ํ•จ์ˆ˜๋ฅผ ํ†ตํ•ด Agent๋Š” ๋” ๋งŽ์€ ๊ฒƒ์„ ํƒ์ƒ‰ํ•˜๊ณ  Local Optimization์„ ํ”ผํ•  ์ˆ˜ ์žˆ๋‹ค.
Maximum Entropy ์›๋ฆฌ์˜ ๋ชฉํ‘œ๋Š” Maximum Entropy์˜ ๋ถ„ํฌ๋ฅผ ์ฐพ๋Š” ๊ฒƒ์ด๋‹ค.
๋งŽ์€ ๊ฐ•ํ™”ํ•™์Šต ์•Œ๊ณ ๋ฆฌ์ฆ˜์—์„œ Agent๋Š” Local Optimization์— ์ˆ˜๋ ดํ•  ์ˆ˜ ๋ฐ–์— ์—†๋Š” ๊ตฌ์กฐ์  ํ•œ๊ณ„๋ฅผ ๊ฐ€์ง€๊ณ  ์žˆ๋‹ค.
์ด๋ฅผ ๊ทน๋ณตํ•˜๊ธฐ ์œ„ํ•ด ๋ชฉ์  ํ•จ์ˆ˜์— Maximum Entropy๋ฅผ ์ถ”๊ฐ€ํ•˜์—ฌ Agent๊ฐ€ Maximum Entropy๋ฅผ ๊ฐ–๋Š” ๋ถ„ํฌ๋ฅผ ๊ฒ€์ƒ‰ํ•  ์ˆ˜ ์žˆ๋‹ค. ์•ž ์„œ๋„ ๋งํ•œ๋ฐ”์™€ ๊ฐ™์ด Maximum Entropy์˜ ๋ชฉํ‘œ๋Š” ๊ฐ€์žฅ ๋†’์€ ๋ˆ„์  ๋ณด์ƒ๊ณผ Maximum Entropy๋ฅผ ๋‹ฌ์„ฑํ•  ์ˆ˜ ์žˆ๋Š” ์ตœ์ ์˜ ์ •์ฑ…์„ ํ•™์Šตํ•˜๋Š” ๊ฒƒ์ด๋‹ค. ์ด ์ตœ์ ์˜ ์ •์ฑ…์€ ๋ฐ”๋กœ ์žฅ๊ธฐ ๋ณด์ƒ ๋ฐ ์žฅ๊ธฐ ์—”ํŠธ๋กœํ”ผ์— ๋Œ€ํ•œ ๊ฐ€์žฅ ๋†’์€ ๊ธฐ๋Œ€์น˜์ด๋‹ค.

๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” pretrained skills๊ฐ€ ์–ด๋–ป๊ฒŒ downstream tasks์— ๋Œ€ํ•ด ์ข‹์€ parameter initialization์„ ์ œ๊ณตํ•  ์ˆ˜ ์žˆ๋Š”์ง€์™€, complexํ•˜๊ณ  sparse reward tasks๋ฅผ ํ’€๊ธฐ ์œ„ํ•ด ๊ณ„์ธต์  ๊ตฌ์„ฑ์„ ํ†ตํ•ด ๋ฌธ์ œ๋ฅผ ํ’€์–ด๊ฐ€๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ฃผ๊ณ  ์žˆ๋‹ค.

Downstream task๋Š” pretrained model ํ˜น์€ feature๋ฅผ supervised-learning task์— ์ ์šฉ์‹œํ‚ค๋Š” ๊ฒƒ์ด๋‹ค.
Transfer learning์—์„œ ๋Œ€๊ทœ๋ชจ ๋ฐ์ดํ„ฐ ์…‹์œผ๋กœ ๋ชจ๋ธ์„ ํ•™์Šตํ•œ ํ›„, ํ•ด๋‹น ๋ชจ๋ธ์— ๋Œ€ํ•œ neural network architecture๋ฅผ ์ผ๋ถ€ ๋ณ€ํ˜•ํ•˜์—ฌ ๋’ท๋‹จ์˜ ๋ถ€๋ถ„๋งŒ ์ผ๋ถ€ ํ•™์Šต์„ ํ†ตํ•˜์—ฌ ๋‹ค์–‘ํ•œ ์„œ๋น„์Šค๋ฅผ ์ˆ˜ํ–‰ํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•˜๋Š” ๊ฒƒ์„ downstream task๋ผ๊ณ  ํ•œ๋‹ค. ์ด๋Š” ๋ณต์žกํ•œ ๊ฒŒ์ž„ ๊ฐ™์€ ๊ฒฝ์šฐ ์ „๋žต์„ ํ•™์Šต์‹œํ‚ค๊ธฐ ์œ„ํ•ด ์‚ฌ๋žŒ์ด ์ง์ ‘ reward๋ฅผ ์„ค๊ณ„ํ•˜๋Š” ๊ฒƒ์ด ์–ด๋ ต๋‹ค.
reward engineering๋‹จ๊ณ„์—์„œ ์—ฐ๊ตฌ์ž์˜ ์˜๋„๊ฐ€ ๋“ค์–ด๊ฐ€๋”๋ผ๋„ ๊ทธ ์˜๋„๋ฅผ ํ•™์Šต์— ์ œ๋Œ€๋กœ ๋ฐ˜์˜ํ•˜๋Š” ๊ฒƒ์ด ๊นŒ๋‹ค๋กญ๊ณ , ์‚ฌ๋žŒ์˜ ๊ฐœ์ž…์—†์ด agent๋ผ๋ฆฌ ํ•™์Šต์„ ํ•˜๋Š” ๊ณผ์ •์—์„œ agent๊ฐ€ ์ฐฝ์˜์ ์ธ ์ „๋žต์„ ์ฐพ์•„๋‚ด๋Š” ๊ฒฝ์šฐ๋„ ์žˆ๊ธฐ ๋•Œ๋ฌธ์— ๋Œ€์ „ํ˜• ๊ฒŒ์ž„์€ ์ฃผ๋กœ ๊ฒŒ์ž„์ด ๋๋‚œ ๋’ค ์ŠนํŒจ์— ๋”ฐ๋ผ reward๋ฅผ ์ œ๊ณตํ•œ๋‹ค.
Agent์˜ ํ•™์Šต episode์—์„œ ๋‹จ ํ•œ๋ฒˆ์˜ reward๋งŒ์ด ์ฃผ์–ด์ง€๊ธฐ ๋•Œ๋ฌธ์— ์œ„์™€ ๊ฐ™์€ ๊ฒฝ์šฐ์—์„œ๋Š” sparse reward problem์ด ์กด์žฌํ•œ๋‹ค. ์ด ๋…ผ๋ฌธ์—์„œ๋Š” ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•œ ๋ฐฉ๋ฒ•๋ก ์œผ๋กœ ๊ณ„์ธต์  ๊ตฌ์„ฑ์„ ํ†ตํ•ด ๋ฌธ์ œ๋ฅผ ํ’€์–ด๊ฐ€๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ฃผ๋ ค ํ•œ๋‹ค.

Introduction
๊ฐ•ํ™”ํ•™์Šต์€ playing games, controlling robots, navigating complex environment๋ฅผ ํฌํ•จ, ๋‹ค์–‘ํ•œ reward driven skills๋ฅผ ํšจ์œจ์ ์œผ๋กœ ํ•™์Šตํ•˜๋Š” ๋ฐฉ๋ฒ•๋ก ์ด๋‹ค. ํ•˜์ง€๋งŒ intelligent creatures๋Š” supervision์—†์ด๋„ useful skills๋ฅผ ์ž˜ ํ•™์Šตํ•˜๊ณ  ํ™˜๊ฒฝ์„ exploreํ•  ์ˆ˜ ์žˆ๋‹ค.
๋”ฐ๋ผ์„œ intelligent creature๋Š” ์ดํ›„์— ๋‹ค๋ฅธ goal์ด ์ƒ๊ฒผ์„ ๋•Œ, skills๋ฅผ ํ™œ์šฉํ•˜์—ฌ ๋น ๋ฅด๊ณ  ํšจ์œจ์ ์œผ๋กœ ํ•™์Šตํ•  ์ˆ˜ ์žˆ๋‹ค.
reward์—†์ด skills๋ฅผ ํ•™์Šตํ•˜๋Š” practical applications๊ฐ€ ์žˆ๋‹ค.
์ด๋Š” sparse rewards๊ฐ€ ์กด์žฌํ•˜๋Š” ํ™˜๊ฒฝ์—์„œ agent๊ฐ€ goal state์— ๋ฌด์ž‘์œ„๋กœ ๋„๋‹ฌํ•  ๋•Œ๊นŒ์ง€์˜ ๋ณด์ƒ์ด ์—†๋‹ค.
supervision์—†์ด useful skills๋ฅผ ํ•™์Šตํ•˜๋Š” ๊ฒฝ์šฐ exploration์˜ ์–ด๋ ค์›€์„ ํ•ด๊ฒฐํ•˜๋Š”๋ฐ์— ๋„์›€์„ ์ฃผ๋Š”๋ฐ ์ด๋Ÿฐ ๋‚ด์—ญ์— ๋Œ€ํ•˜์—ฌ ๋ณด์ƒ์„ ์„ค๊ณ„ํ•˜๋Š” ๊ฒƒ์€ ๋งค์šฐ ์–ด๋ ต๊ณ  ์ด๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ agent๋กœ ๋ถ€ํ„ฐ ์›ํ•˜๋Š” ํ–‰๋™์„ ์œ ๋„ํ•˜๋Š” reward function์„ ์„ค๊ณ„ํ•˜๋Š” ๊ฒƒ์€ ๋”์šฑ ์–ด๋ ต๋‹ค. ๊ทธ๋ ‡๊ธฐ ๋•Œ๋ฌธ์— ์ต์ˆ™ํ•˜์ง€ ์•Š์€ ํ™˜๊ฒฝ์ด ์ฃผ์–ด์กŒ์„ ๋•Œ, agent๊ฐ€ ์–ด๋–ค tasks๋ฅผ ํ•™์Šตํ•  ์ˆ˜ ์žˆ๋Š”์ง€ ๊ฒฐ์ •ํ•˜๋Š” ๊ฒƒ์ด challenging์ด๋‹ค. ๊ทธ๋ฆฌ๊ณ  ๋ณด์ƒ์—†์ด useful skills๋ฅผ ํ•™์Šตํ•˜๋Š” ๊ฒƒ ๋˜ํ•œ ๋งค์šฐ ์–ด๋ ค์šด ๋ฌธ์ œ์ด๋‹ค.
๊ฒฐ๋ก ์ ์œผ๋กœ ๋งํ•˜์ž๋ฉด ๋ณด์ƒ์„ ๋””์ž์ธํ•˜๋Š” ๊ฒƒ์€ ๋งค์šฐ ์–ด๋ ต์ง€๋งŒ ๋ณด์ƒ์—†์ด ํ•™์Šตํ•˜๋Š” ๊ฒƒ ๋˜ํ•œ ๋งค์šฐ ์–ด๋ ค์šด ๋ฌธ์ œ์ด๋‹ค.
์œ„์—์„œ ๊ณ„์† ์–ธ๊ธ‰ํ•˜๋Š” Skill์ด๋ž€ ํ™˜๊ฒฝ์˜ state๋ฅผ consistent way(์ผ๊ด€๋œ ๋ฐฉ์‹)๋กœ ๋ณ€ํ™”์‹œํ‚ค๋Š” latent conditioned policy์ด๋‹ค. reward function์„ ๋ชจ๋ฅผ ๋•Œ๋„ ‘set’์˜ utility๋ฅผ ์ตœ๋Œ€ํ™”ํ•˜์—ฌ skills์˜ set์„ ํ•™์Šตํ•˜๊ณ  ์‹ถ๋‹ค.
์ƒํ˜ธ ์ •๋ณด(mutual information)์— ๊ธฐ๋ฐ˜ํ•œ ๊ฐ„๋‹จํ•œ ๋ชฉํ‘œ๋ฅผ ํ†ตํ•ด ๊ฐ•ํ™”ํ•™์Šต agent๊ฐ€ ์ด๋Ÿฐ skills๋ฅผ ์ž์œจ์ ์œผ๋กœ ๋ฐœ๊ฒฌํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์„ ๋ณผ ์ˆ˜ ์žˆ๋‹ค๊ณ  ํ•œ๋‹ค. ์ด๋Ÿฐ skills๋Š” hierarchical reinforcement learning์ด๋‚˜ imitation learning๋“ฑ์— ํ™œ์šฉ๋˜๊ณ  ์žˆ๋‹ค.
๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” useful skills๋ฅผ ์–ป๊ธฐ ์œ„ํ•ด, ๊ฐ€๋Šฅํ•œ ํ–‰๋™์˜ set๋“ค์„ ์ตœ๋Œ€ํ•œ ์ปค๋ฒ„ํ•  ์ˆ˜ ์žˆ๋„๋ก skills๋ฅผ ํ›ˆ๋ จํ•ด์•ผ ํ•œ๋‹ค๊ณ  ๊ฐ€์ •ํ•œ๋‹ค. ํ›ˆ๋ จ๋ฐฉ๋ฒ•์œผ๋กœ๋Š” objective๋กœ์„œ์˜ skills ์‚ฌ์ด์˜ ์ฐจ๋ณ„์„ฑ(discriminability)์„ ์ด์šฉํ•˜๋Š” ๊ฒƒ์ด๋‹ค.
๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” discriminabilityํ•˜๋ฉด์„œ ๋‹ค์–‘ํ•œ skills๋ฅผ ํ•™์Šตํ•ด์•ผ ํ•œ๋‹ค.

๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” ์ด 5๊ฐ€์ง€์˜ contribution์ด ์žˆ๋‹ค.

1> It formalizes the discriminability goal as maximizing an information-theoretic objective with a maximum entropy policy, and shows that on environments used as benchmarks in many RL studies the method learns well without ever receiving the true reward.
2> A method for learning useful skills without any reward.
It shows that this simple exploration objective results in the unsupervised emergence of diverse skills such as running and jumping.
3> It proposes a simple method for using the learned skills in hierarchical RL and shows that this solves challenging tasks.
At each state an appropriate action must be taken. People decide actions hierarchically: a high-level action is chosen first, and then the low-level actions needed to carry it out. This is Hierarchical RL (HRL).
4> It shows how quickly the skills adapt to a new task.
5> It shows how the discovered skills can be used for imitation learning.
Why imitation learning? In a general MDP, finding a good policy requires a very large number of samples. DQN, for example, only reached good performance after long training on as many samples as possible. In the real world, however, collecting enough samples to run RL directly is rarely feasible.
Suppose we had to build an RL agent that launches rockets. We would have to fail countless launches before seeing a properly executed one; the cost would be astronomical, and in practice this style of RL simply cannot be used for rocket launches. The same goes for self-driving cars: if the car could only drive properly after countless accidents, nobody would accept the cost.
Can we reduce the number of samples, then? Instead of expecting an optimal policy from nothing, we can supply additional information or structure that helps the learning process before training begins; imitation learning starts from exactly this assumption.
RL so far has trained agents through rewards: DQN, Q-learning, Monte Carlo methods and the rest all use a reward function and aim to collect the maximum reward. This is attractive because training is very simple, but it has the drawback of requiring far too many samples.
How should the reward of a self-driving car be defined? What if we set it to -10 when an accident occurs and +0.1 otherwise? An agent designed that way would spend an enormous amount of effort just figuring out what causes accidents and what avoids them. What if we instead assigned an appropriate reward for every possible situation? That would take far too long, and even then the resulting reward signal could be very unstable. The alternative is to give the reward implicitly through demonstration, i.e. by showing how the task is actually done. How do we obtain rewards from demonstrations? By building demonstration trajectories together with an expert in the target domain and learning from them.
For a self-driving car, this means bringing in a very good driver and observing them actually drive.
The state/action sequences obtained this way are then used to train the RL agent.
Imitation learning is efficient when we do not want to hand out rewards one by one or force the agent to follow a specific hand-crafted policy.
The basic setting of imitation learning is as follows: as before, the input consists of a state space, an action space, and a transition model P; the difference is that no reward function R is given, and instead demonstrations such as (s0, a0, s1, a1, ...) are provided.

Diversity is all you need

๋ณธ ๋…ผ๋ฌธ์€ unsupervised RL ํŒจ๋Ÿฌ๋‹ค์ž„์„ agent๊ฐ€ unsupervised “exploration” stage์— ์ด์–ด์„œ supervised stage๋„ ํ—ˆ์šฉ๋œ work์—์„œ ๊ณ ๋ คํ•œ๋‹ค. Unsupervised stage์˜ ๋ชฉ์ ์€ ๊ถ๊ทน์ ์œผ๋กœ supervised stage์˜ task reward๋ฅผ ์‰ฝ๊ฒŒ ์ตœ๋Œ€ํ™”ํ•˜๋Š” skills๋ฅผ ํ•™์Šตํ•˜๋Š” ๊ฒƒ์ด๋‹ค.
ํŽธ๋ฆฌํ•˜๊ฒŒ๋„ ์ด๋Ÿฌํ•œ skills๋Š” tasks์— ๋Œ€ํ•œ ์‚ฌ์ „ ์ง€์‹์—†์ด skills๋ฅผ ํ•™์Šตํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋งŽ์€ ๋‹ค๋ฅธ tasks์—์„œ ์‚ฌ์šฉ๋  ์ˆ˜ ์žˆ๋‹ค.

How it works

unsupervised skill์˜ ๋ฐœ๊ฒฌ์„ ์œ„ํ•œ ๋ฐฉ๋ฒ•์€ ์„ธ ๊ฐ€์ง€ ์•„์ด๋””์–ด๋กœ ๊ตฌ์„ฑ๋œ๋‹ค.
1> Agent๊ฐ€ visitํ•˜๋Š” state์— ์˜ํ•ญ์„ ์ฃผ๋Š” skill์ด ์žˆ์–ด์•ผ ํ•œ๋‹ค. ( for skills to be useful )
    ๋‹ค๋ฅธ skills๋Š” ์„œ๋กœ ๋‹ค๋ฅธ states์— ๋ฐฉ๋ฌธํ•ด์•ผ๋งŒ ํ•˜๋ฉฐ, ๊ทธ๋ž˜์„œ ๊ตฌ๋ถ„๋  ์ˆ˜ ์žˆ๋‹ค.
2> Skills๋ฅผ ๊ตฌ๋ถ„ํ•˜๊ธฐ ์œ„ํ•ด action์ด ์•„๋‹Œ state๋ฅผ ์ด์šฉํ•œ๋‹ค.
    ํ™˜๊ฒฝ์— ์˜ํ–ฅ์„ ์ฃผ์ง€ ์•Š๋Š” actions์€ outside observer์—๊ฒŒ ๋ณด์ด์ง€ ์•Š๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค.
    ์˜ˆ๋ฅผ ๋“ค์–ด, outside observer๋Š” ์ปต์ด ์›€์ง์ด์ง€ ์•Š์œผ๋ฉด ์ปต์„ ์žก์„ ๋•Œ ๋กœ๋ด‡ ํŒ”์ด ์–ผ๋งˆ๋‚˜ ๋งŽ์€ ํž˜์„ ๊ฐ€ํ•˜๋Š”์ง€ ์•Œ ์ˆ˜ ์—†๋‹ค.
    ๋งˆ์ง€๋ง‰์œผ๋กœ, ๊ฐ€๋Šฅํ•œ ํ•œ ๋žœ๋คํ•˜๊ฒŒ ํ–‰๋™ํ•˜๋Š” skills์„ ํ•™์Šตํ•˜๋Š” ๊ฒƒ์„ ํ†ตํ•ด skills์ด ๊ฐ€๋Šฅํ•œ ๋‹ค์–‘ํ•ด์ง€๋„๋ก ํ•œ๋‹ค.
3> ๊ตฌ๋ถ„ ๊ฐ€๋Šฅํ•œ ๋†’์€ entropy๋ฅผ ๊ฐ€์ง„ skills๋Š” ๋‹ค๋ฅธ skills๋กœ๋ถ€ํ„ฐ ๋ฉ€๋ฆฌ ๋–จ์–ด์ง„ state space์˜ ์ผ๋ถ€๋ฅผ ๊ผญ explorationํ•ด์„œ action์˜ ๋žœ๋ค์„ฑ์ด ๊ตฌ๋ณ„ํ•  ์ˆ˜ ์—†๋Š” state๋กœ ์ด๋Œ์ง€ ์•Š๋„๋ก ํ•œ๋‹ค.

Objective๋ฅผ ์œ„ํ•œ notation์€ ์ •๋ณด์ด๋ก ์œผ๋กœ๋ถ€ํ„ฐ ๊ฐ€์ ธ์™€ ์‚ฌ์šฉํ•œ๋‹ค.
– ๊ณ ์ •๋œ Z๋ฅผ ์กฐ๊ฑด์œผ๋กœ ํ•˜๋Š” policy๋ฅผ skill์ด๋ผ๊ณ  ํ•œ๋‹ค.
– I์™€ H๋Š” ์ƒํ˜ธ ์ •๋ณด(mutual information)์™€ Shannon entropy์˜ ๊ฐœ๋…์—์„œ ์ฐฉ์•ˆํ•˜์˜€์œผ๋ฉฐ, ๋‘˜ ๋‹ค base e๋กœ ๊ณ„์‚ฐ๋œ๋‹ค.
– S์™€ A๋Š” ๊ฐ๊ฐ States, Actions์— ๋Œ€ํ•œ ์ž„์˜์˜ ๋ณ€์ˆ˜์ด๋‹ค.
– Z ~ p(z)๋Š” latent ๋ณ€์ˆ˜์ด๋‹ค.
– ๋ณธ ๋…ผ๋ฌธ์˜ objective๋Š”, skill์ด agent๊ฐ€ visitํ•˜๋Š” state๋ฅผ ์ œ์–ดํ•ด์•ผ๋งŒ ํ•˜๋Š” ์•„์ด๋””์–ด๋ฅผ ์ธ์ฝ”๋”ฉํ•˜๊ธฐ ์œ„ํ•ด skills์™€ states๊ฐ„์˜ ์ƒํ˜ธ ์ •๋ณด( I(A;Z) )๋ฅผ ์ตœ๋Œ€ํ™”ํ•œ๋‹ค.
– ํŽธ๋ฆฌํ•˜๊ฒŒ๋„, ์ด ์ƒํ˜ธ ์ •๋ณด๋Š” agent๊ฐ€ visitํ•œ state์—์„œ skill์„ ์ถ”๋ก ํ•  ์ˆ˜ ์žˆ์Œ์„ ๋‚˜ํƒ€๋‚ธ๋‹ค.
– actions์ด ์•„๋‹Œ states๊ฐ€ skills๋ฅผ ๊ตฌ๋ถ„ํ•˜๋Š”๋ฐ ์‚ฌ์šฉ๋˜๋„๋ก ํ•˜๊ธฐ ์œ„ํ•ด, ์ฃผ์–ด์ง„ state์˜ skill๊ณผ action๊ฐ„์˜ ์ƒํ˜ธ ์ •๋ณด( I(A;Z | S) )๋ฅผ ์ตœ์†Œํ™”ํ•œ๋‹ค.

Policies์˜ mixture๋กœ p(z)์™€ ํ•จ๊ป˜ ๋ชจ๋“  skills๋ฅผ ๋ณผ ๋•Œ ์ด mixture policy์˜ ์—”ํŠธ๋กœํ”ผ H[A | S]๋ฅผ ์ตœ๋Œ€ํ™”ํ•œ๋‹ค.

์š”์•ฝํ•˜์ž๋ฉด, ์•„๋ž˜์˜ g(theta)๋ฅผ ์ตœ๋Œ€ํ™”ํ•˜๋Š” ๊ฒƒ์ด๋‹ค.

Implementation

๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” SAC(Soft Actor Critic)์„ ์ด์šฉํ•˜์˜€๊ณ , latent ๋ณ€์ˆ˜ z๋ฅผ ์กฐ๊ฑด์œผ๋กœ ํ•˜๋Š” ์ •์ฑ…์„ ํ•™์Šตํ•œ๋‹ค. 
SAC๋Š” actions์— ๋Œ€ํ•œ policyโ€™s entropy๋ฅผ ์ตœ๋Œ€ํ™”ํ•˜์—ฌ object g์—์„œ ์—”ํŠธ๋กœํ”ผ ํ…€์„ ์ฒ˜๋ฆฌํ•œ๋‹ค.
task reward๋ฅผ ๋Œ€์ฒดํ•˜๋Š” ๋ฐฉ๋ฒ•์œผ๋กœ skill reward๋กœ๋Š” g์˜ expectation์„ ์ด์šฉํ•œ๋‹ค.
– ์œ„ ๊ทธ๋ฆผ์—์„œ logq – logp์— ํ•ด๋‹นํ•˜๋Š” ๋ถ€๋ถ„
– p(z) ์œ„ํ•ด categorical distribution ์ด์šฉํ•˜์˜€๋‹ค.

Unsupervised learning๋™์•ˆ ๊ฐ ์—ํ”ผ์†Œ๋“œ์˜ ์‹œ์ž‘๋ถ€๋ถ„์—์„œ skill z ~ p(z)๋ฅผ ์ƒ˜ํ”Œ๋งํ•˜๊ณ  ์—ํ”ผ์†Œ๋“œ๋‚ด๋‚ด ํ•ด๋‹น skill์„ ๋”ฐ๋ผ ํ–‰๋™ํ•œ๋‹ค.
Agent๋Š” ๊ตฌ๋ถ„ํ•˜๊ธฐ ์‰ฌ์šด states๋ฅผ ๋ฐฉ๋ฌธํ•˜๋ฉด ๋ณด์ƒ์„ ๋ฐ›๋Š” ๋ฐ˜๋ฉด์—, discriminator๋Š” ๋ฐฉ๋ฌธํ•œ states๋กœ๋ถ€ํ„ฐ skill z๋ฅผ ๋” ์ž˜ ์ถ”๋ก ํ•˜๊ธฐ ์œ„ํ•˜์—ฌ ์—…๋ฐ์ดํŠธ ๋œ๋‹ค.
๊ทธ๋ฆฌ๊ณ  SAC ์—…๋ฐ์ดํŠธ์˜ ์ผ๋ถ€๋กœ์„œ ์—”ํŠธ๋กœํ”ผ ์ •๊ทœํ™”๊ฐ€ ์ง„ํ–‰๋œ๋‹ค.
[์ถ”๊ฐ€] ๋งŒ์•ฝ ๋ชจ๋“  tasks์™€ ์—ฐ๊ด€๋˜์–ด ์žˆ๋Š” ์ฃผ์–ด์ง„ latent variable z~p(z)๊ฐ€ ์žˆ๋‹ค๊ณ  ๊ฐ€์ •ํ–ˆ์„ ๋•Œ, ํ•ด๋‹น ๊ฐ’์„ ์ด์šฉํ•˜์—ฌ reward function์„ ์ •์˜ํ•  ์ˆ˜ ์žˆ๋‹ค.
reward function = logD(z|s), ์—ฌ๊ธฐ์„œ D๋Š” discriminator ํ•จ์ˆ˜์ธ๋ฐ state๋กœ๋ถ€ํ„ฐ latent variable์„ ๋ฝ‘๋Š”๋ฐ ์‚ฌ์šฉํ•œ๋‹ค.
D(z|s) = q(z|s)
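
A minimal sketch of how this skill reward can be computed (my own illustration, not the authors' code), assuming K skills, a uniform prior p(z) = 1/K, and a small discriminator network that outputs logits over skills; n_skills and state_dim are illustrative values:

import torch
import torch.nn as nn
import torch.nn.functional as F

n_skills = 50        # K skills; illustrative value
state_dim = 3        # hypothetical state dimension (e.g. a Pendulum observation)

# discriminator q_phi(z | s): maps a state to logits over skills
discriminator = nn.Sequential(nn.Linear(state_dim, 128), nn.ReLU(),
                              nn.Linear(128, n_skills))
disc_optim = torch.optim.Adam(discriminator.parameters(), lr=3e-4)

log_p_z = torch.log(torch.tensor(1.0 / n_skills))   # uniform categorical prior p(z)

def diayn_reward(state, z):
    """Pseudo-reward r_z(s) = log q_phi(z|s) - log p(z), used instead of the task reward."""
    with torch.no_grad():
        log_q = F.log_softmax(discriminator(state), dim=-1)[..., z]
    return (log_q - log_p_z).item()

def update_discriminator(states, skills):
    """Train q_phi(z|s) to predict which skill generated the visited states."""
    loss = F.cross_entropy(discriminator(states), skills)
    disc_optim.zero_grad()
    loss.backward()
    disc_optim.step()
    return loss.item()

During the unsupervised stage, diayn_reward would replace the reward returned by the environment, and update_discriminator would be called on batches of (state, skill) pairs drawn from the replay buffer.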

[์ถ”๊ฐ€] ํ•œ๋ฒˆ discriminator function์ด ํ•™์Šต๋˜๋ฉด, training์„ ์œ„ํ•œ ์ƒˆ๋กœ์šด MDP๋ฅผ samplingํ•˜๋Š” ๊ฒƒ์€ ๋งค์šฐ ์ง๊ด€์ ์ด ๋œ๋‹ค. ์šฐ์„  latent variable z~p(z)๋ฅผ ์ƒ˜ํ”Œ๋งํ•˜๊ณ  ์ด๋ฅผ ์ด์šฉํ•ด์„œ reward function r(s) =logD(z|s)๋ฅผ ๋งŒ๋“ค์–ด๋‚ธ๋‹ค.

(์œ„ ์ถ”๊ฐ€๋‚ด์šฉ ์ถœ์ฒ˜ : talkingaboutme.tistory.com/entry/RL-Meta-Reinforcement-Learning)

Conclusion

๋ณธ ๋…ผ๋ฌธ์—์„œ๋Š” reward function์—†์ด skills๋ฅผ ํ•™์Šตํ•˜๋Š” ๋ฐฉ๋ฒ•์œผ๋กœ DIAYN์„ ์ œ์•ˆํ•˜๊ณ  ์žˆ๋‹ค.
DIAYN์€ ๋ณต์žกํ•œ tasks์— ๋Œ€ํ•ด ๋‹ค์–‘ํ•œ skills๋ฅผ ํ•™์Šตํ•˜๊ณ , ์‹ค์ œ๋กœ task reward๋ฅผ ๋ฐ›๋Š” ๊ณผ์ •์—†์ด bench mark tasks๋ฅผ ํ•ด๊ฒฐํ•˜์˜€๋‹ค.

1) ๋น ๋ฅด๊ฒŒ ์ƒˆ๋กœ์šด task์— ์ ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ•
2) HRL์„ ์ด์šฉํ•˜์—ฌ ๋ณต์žกํ•œ tasks๋ฅผ ํ•ด๊ฒฐํ•˜๋Š” ๋ฐฉ๋ฒ•
3) ์ „๋ฌธ๊ฐ€๋ฅผ ๋ชจ๋ฐฉํ•˜์—ฌ ํ•™์Šตํ•˜๋Š” ๋ฐฉ๋ฒ•
์œ„ 3๊ฐ€์ง€ ๋ฐฉ๋ฒ•๋“ค์„ ์ œ์•ˆํ•˜๊ณ  ์žˆ๋‹ค.
์ผ๋ฐ˜์ ์œผ๋กœ DIAYN์€ task์˜ ๋ณต์žกํ•œ action space๋ฅผ useful skills์˜ set์œผ๋กœ ๋Œ€์ฒดํ•˜์—ฌ task๋ฅผ ๋” ์‰ฝ๊ฒŒ ํ•™์Šตํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•œ๋‹ค. DIAYN์€ observation space์™€ reward function์„ ์ฆ๊ฐ€์‹œํ‚ค๊ธฐ ์œ„ํ•œ ๋ฐฉ๋ฒ•๋“ค๊ณผ ๊ฒฐํ•ฉ๋  ์ˆ˜ ์žˆ๋‹ค.
๊ฒฐ๋ก ์ ์œผ๋กœ task reward function์„ skill reward function์œผ๋กœ ๋Œ€์ฒดํ•˜์—ฌ ์‚ฌ์šฉํ•œ๋‹ค๋Š” ๊ฒƒ์ด๋‹ค.

CODE

์ฝ”๋“œ ๊ตฌํ˜„ ๋‚ด์—ญ์— ๋Œ€ํ•˜์—ฌ ์„ค๋ช…ํ•ด ๋ณด๋„๋ก ํ•˜๊ฒ ๋‹ค
๊ธฐ๋ณธ์ ์ธ ๊ณจ๊ฒฉ์€ SAC๋ฅผ ์ฐธ๊ณ ํ•˜๋ฉด ๋˜๋Š” ๋ถ€๋ถ„์ธ์ง€๋ผ SAC ์ฝ”๋“œ๋ฅผ ๋จผ์ € ์‚ดํŽด๋ณธ ํ›„ DIAYN์—์„œ ๊ฐœ์„ ๋œ ์ฝ”๋“œ ๋‚ด์—ญ์„ ์‚ดํŽด๋ณด๋Š” ๊ฒƒ์œผ๋กœ ์ง„ํ–‰ํ•˜๋„๋ก ํ•˜๊ฒ ๋‹ค.

SAC CODE

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Normal
import numpy as np
import random
import matplotlib.pyplot as plt

#===============================================================================
# global variables
#===============================================================================
seed = 1
gamma = 0.99
tau = 0.005
alpha = 0.2
lr = 0.0003
hidden_size = 256
epsilon = 1e-6
replay_size = 1000000
start_steps = 10000
updates_per_step = 1
batch_size = 256
num_steps = 1000000

def weights_init_(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=1)
        torch.nn.init.constant_(m.bias, 0)

def hard_update(target, source):
    # copy the source parameters into the target network in one shot
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)

def soft_update(target, source, tau):
    # Polyak averaging: target <- (1 - tau) * target + tau * source
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

class ReplayMemory:
    def __init__(self, capacity, seed):
        random.seed(seed)
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

class QNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_dim):
        super(QNetwork, self).__init__()

        # Q1 architecture
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        # Q2 architecture
        self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, hidden_dim)
        self.linear6 = nn.Linear(hidden_dim, 1)

        self.apply(weights_init_)

    def forward(self, state, action):
        xu = torch.cat([state, action], 1)

        x1 = F.relu(self.linear1(xu))
        x1 = F.relu(self.linear2(x1))
        x1 = self.linear3(x1)

        x2 = F.relu(self.linear4(xu))
        x2 = F.relu(self.linear5(x2))
        x2 = self.linear6(x2)

        return x1, x2

class ValueNetwork(nn.Module):
    def __init__(self, num_inputs, hidden_dim):
        super(ValueNetwork, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        self.apply(weights_init_)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x

class GaussianPolicy(nn.Module):
    def __init__(self, num_inputs, num_actions, hidden_dim):
        super(GaussianPolicy, self).__init__()

        self.linear1 = nn.Linear(num_inputs, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

        self.mean_linear = nn.Linear(hidden_dim, num_actions)
        self.log_std_linear = nn.Linear(hidden_dim, num_actions)

        self.apply(weights_init_)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=-20, max=2)
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        x_t = normal.rsample()  # for reparameterization trick (mean + std * N(0,1))
        action = torch.tanh(x_t)
        log_prob = normal.log_prob(x_t)
        # Enforcing Action Bound
        log_prob -= torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(1, keepdim=True)
        return action, log_prob, mean, log_std

class SAC(object):
    def __init__(self, num_inputs, action_space):
        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        self.action_range = [action_space.low, action_space.high]
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.critic = QNetwork(num_inputs, action_space.shape[0], hidden_size).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=lr)

        self.value = ValueNetwork(num_inputs, hidden_size).to(device=self.device)
        self.value_target = ValueNetwork(num_inputs, hidden_size).to(self.device)
        self.value_optim = Adam(self.value.parameters(), lr=lr)
        hard_update(self.value_target, self.value)

        self.policy = GaussianPolicy(num_inputs, action_space.shape[0], hidden_size).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=lr)

    def select_action(self, state):
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        action, _, _, _ = self.policy.sample(state)
        action = action.detach().cpu().numpy()[0]
        return self.rescale_action(action)

    def rescale_action(self, action):
        return action * (self.action_range[1] - self.action_range[0]) / 2.0 + \
               (self.action_range[1] + self.action_range[0]) / 2.0

    def update_parameters(self, memory, batch_size, updates):
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            vf_next_target = self.value_target(next_state_batch)
            next_q_value = reward_batch + mask_batch * self.gamma * (vf_next_target)

        qf1, qf2 = self.critic(state_batch, action_batch)
        qf1_loss = F.mse_loss(qf1, next_q_value)
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss

        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        pi, log_pi, mean, log_std = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        # Regularization Loss (optional)
        reg_loss = 0.001 * (mean.pow(2).mean() + log_std.pow(2).mean())
        policy_loss += reg_loss

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        vf = self.value(state_batch)

        with torch.no_grad():
            vf_target = min_qf_pi - (self.alpha * log_pi)

        vf_loss = F.mse_loss(vf, vf_target)

        self.value_optim.zero_grad()
        vf_loss.backward()
        self.value_optim.step()

        soft_update(self.value_target, self.value, self.tau)

        return vf_loss.item(), qf1_loss.item(), qf2_loss.item(), policy_loss.item()

def main():
    env = gym.make('Pendulum-v1')

    env.seed(seed)
    env.action_space.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    agent = SAC(env.observation_space.shape[0], env.action_space)
    memory = ReplayMemory(replay_size, seed)

    # Training Loop
    total_numsteps = 0
    updates = 0
    ep_r_store = []

    for i_episode in range(1000):
        episode_reward = 0
        episode_steps = 0
        done = False
        state = env.reset()

        while not done:
            if start_steps > total_numsteps:
                action = env.action_space.sample()
            else:
                action = agent.select_action(state)  # Sample action from policy

            if len(memory) > batch_size:
                for i in range(updates_per_step):  # Number of updates per step in environment
                    # Update parameters of all the networks
                    value_loss, critic_1_loss, critic_2_loss, policy_loss = agent.update_parameters(memory, batch_size, updates)
                    updates += 1

            next_state, reward, done, _ = env.step(action)  # Step
            episode_steps += 1
            total_numsteps += 1
            episode_reward += reward

            # Ignore the "done" signal if it comes from hitting the time horizon.
            # (https://github.com/openai/spinningup/blob/master/spinup/algos/sac/sac.py)
            mask = 1 if episode_steps == env._max_episode_steps else float(not done)

            memory.push(state, action, reward, next_state, mask)  # Append transition to memory

            state = next_state

        if done:
            ep_r_store.append(episode_reward)

        if total_numsteps > num_steps:
            break

        print("Episode: {}, total numsteps: {}, episode steps: {}, reward: {}".format(
            i_episode, total_numsteps, episode_steps, round(episode_reward, 2)))

    env.close()

    plt.plot(ep_r_store)
    plt.title('SAC')
    plt.xlabel('episode number')
    plt.ylabel('return')
    plt.grid(True)
    plt.savefig("sac.png")


if __name__ == '__main__':
    main()

hard_update์™€ soft_update์˜ ์ฐจ์ด๋ฅผ ์ž˜ ์•Œ์•„์•ผ ํ•œ๋‹ค. ์ด๋ฅผ ์ดํ•ดํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” full batch์™€ mini batch์ฐจ์ด๋ฅผ ์•„๋Š” ๊ฒƒ์ด ๋„์›€์ด ๋œ๋‹ค. hard_update๋ฅผ full batch๋กœ soft_update๋ฅผ mini batch๋กœ ์ดํ•ดํ•˜๋ฉด ๋œ๋‹ค.

replay buffer๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ Environment์—์„œ Policy์„ ์‹คํ–‰ํ•  ๋•Œ ๊ฒฝํ—˜์˜ ๊ถค์ ์„ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค. ํ›ˆ๋ จ ์ค‘์—, ์—์ด์ „ํŠธ์˜ ๊ฒฝํ—˜์„ “replay”ํ•˜๊ธฐ ์œ„ํ•ด ๊ถค์ ์˜ ์„œ๋ธŒ ์„ธํŠธ(์ˆœ์ฐจ ์„œ๋ธŒ ์„ธํŠธ ๋˜๋Š” ์ƒ˜ํ”Œ)์— ๋Œ€ํ•ด ์žฌํ˜„ ๋ฒ„ํผ๊ฐ€ ์กฐํšŒ๋จ

Value Network๋ฅผ Q-Network๋กœ ๊ตฌํ˜„ํ•œ ๋‚ด์—ญ์ด๋‹ค.

Value Network์˜ ๊ตฌํ˜„์ฒด๋Š” ์•„๋ž˜์™€ ๊ฐ™๋‹ค.

Actor network
– SAC uses a stochastic policy over a continuous action space.
– Given a state, the actor network (GaussianPolicy above) produces a mean and a log standard deviation for each action dimension.
– The action is treated as a sample from a Gaussian distribution with that mean and standard deviation.
– Instead of predicting the standard deviation directly, the network outputs the log standard deviation and converts it to a standard deviation afterwards (log_std.exp()).
– log_std is clamped so that the resulting standard deviation is neither too large nor too small.

SAC agent: the SAC class above ties the critic, value, and policy networks together and performs the updates.

DIAYN CODE

https://github.com/alirezakazemipour/DIAYN-PyTorch

https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/blob/master/agents/hierarchical_agents/DIAYN.py
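
At a high level, the linked implementations change the SAC skeleton above in three ways: the policy and critics see the state concatenated with a one-hot encoding of the sampled skill, a discriminator q_phi(z|s) is trained alongside SAC, and the reward stored in the replay buffer is the pseudo-reward log q_phi(z|s) - log p(z) instead of the environment reward. The sketch below is my rough illustration of those changes (augment and run_diayn_episode are my names, not the repositories'); it reuses diayn_reward and update_discriminator from the earlier sketch and assumes the SAC agent and ReplayMemory above, with the agent built for inputs of size state_dim + n_skills.

import numpy as np
import torch

n_skills = 50  # K skills; the prior p(z) is uniform over them

def augment(state, z):
    """The policy and critics see [state, one-hot(z)] instead of the raw state."""
    one_hot = np.zeros(n_skills, dtype=np.float32)
    one_hot[z] = 1.0
    return np.concatenate([state, one_hot])

def run_diayn_episode(env, agent, memory):
    """One unsupervised episode: sample a skill, act with it, store pseudo-rewards.
    `agent` is the SAC agent above built with num_inputs = state_dim + n_skills."""
    z = np.random.randint(n_skills)      # the skill is fixed for the whole episode
    state = env.reset()
    done = False
    while not done:
        action = agent.select_action(augment(state, z))
        next_state, _, done, _ = env.step(action)   # the true env reward is discarded
        reward = diayn_reward(torch.as_tensor(next_state, dtype=torch.float32), z)
        memory.push(augment(state, z), action, reward,
                    augment(next_state, z), float(not done))
        state = next_state
    # alongside the SAC updates, update_discriminator() is called on (state, z)
    # pairs so that q_phi(z|s) keeps improving as the skills change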
