
Speed Up Model Training with 3 Lines of Code: This Algorithm Gives Your Old GPU a New Lease on Life

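Baidu and NVIDIA Research, building on low-level compute optimizations in NVIDIA GPUs, proposed an effective method for accelerating neural network training. It helps not only with pre-training: now that everyone is fine-tuning BERT, it has become exceptionally useful.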

It all starts with a paper from ICLR 2018.

The paper "Mixed Precision Training" was published jointly by Baidu and NVIDIA Research.

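Our survey also found that mixed precision training is supported not only in Baidu's Paddle framework but also in TensorFlow and PyTorch. Below we first cover the theory, then look at how to switch on mixed precision training in each of these three major deep learning frameworks.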

Theory

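Anyone who has trained a neural network knows that almost all of its parameters and intermediate results are stored and computed as single-precision floating-point numbers (float32). When a network becomes huge, lowering the floating-point precision, for example to half precision (float16), is an obvious and direct way to speed up computation and cut memory use.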

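The side effect is just as obvious, though: naively lowering floating-point precision should, intuitively, cost training accuracy. But there is always someone cleverer: this paper uses three mechanisms that effectively prevent the loss of model accuracy. Let Xiaoxi walk you through them one by one o(*￣▽￣*)ブ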

Weight backup (master weights)

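Half-precision floats (float16) are stored as 1 sign bit, 5 exponent bits and 10 mantissa bits, so the smallest positive number they can represent is 2^-24 (a subnormal value); below that, precision simply runs out. When the gradients in a network are very small, each training step's update (a very small gradient × an equally small learning rate) is smaller still, and once it falls below what float16 can represent, the corresponding weight does not get updated at all.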
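A minimal sketch of both facts using only the Python standard library: the `struct` format code `"e"` packs a float as IEEE 754 half precision, so a pack/unpack round trip rounds a Python float to the nearest float16. The helpers `to_fp16` and `fp16_fields` and the gradient and learning-rate values are hypothetical, chosen only for illustration.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest float16 value, returning +/-inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # struct raises instead of producing inf
        return float("inf") if x > 0 else float("-inf")

def fp16_fields(x: float) -> tuple:
    """Return the (sign, exponent, mantissa) bit fields of x's float16 encoding."""
    bits = struct.unpack("<H", struct.pack("<e", x))[0]
    return bits >> 15, (bits >> 10) & 0x1F, bits & 0x3FF  # 1 + 5 + 10 bits

print(fp16_fields(1.0))          # (0, 15, 0): exponent stored with a bias of 15
print(fp16_fields(2.0 ** -24))   # (0, 0, 1): smallest positive (subnormal) float16
print(to_fp16(2.0 ** -26))       # 0.0: anything at or below 2**-25 rounds to zero

grad, lr = 1e-4, 1e-4            # hypothetical small gradient and learning rate
update = to_fp16(to_fp16(grad) * to_fp16(lr))  # roughly 1e-8 before rounding
print(update)                    # 0.0: the weight update vanishes in float16
```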

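The paper collected the gradients produced while training a DeepSpeech 2 model on a Mandarin dataset and found that, even before multiplying by the learning rate, nearly 5% of them tragically become 0 (gradients smaller than float16's smallest representable magnitude, 2^-24, underflow to zero), which is a serious loss /(ㄒoㄒ)/~~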

»¹ÓиüÄѵģ¬¼ÙÉèµü´úÁ¿ÌÓ¹ýÒ»½Ù×¼±¸·îÏ××Ô¼ºµÄʱºò¡£¡£¡£ÓÉÓÚÍøÂçÖеÄȨÖØÍùÍùÔ¶´óÓÚÎÒÃÇÒª¸üеÄÁ¿£¬µ±µü´úÁ¿Ð¡ÓÚFloat16µ±Ç°Çø¼äÄÚÄܱíʾµÄ×îС¼ä¸ôµÄʱºò£¬¸üÐÂÒ²»áʧ°Ü£¨¿ÞϹ©Ñ©Ò©n©Ñ©ÒÎÒÔõôÕâôÄÑѼ£©

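So what can we do? The authors propose a very simple but effective method: run the forward pass and the gradient computation in float16, but keep the network parameters themselves in float32 as a master copy, apply the updates to that copy, and make a float16 copy of the weights for each forward and backward pass. This goes a long way toward solving both of the problems above~~~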
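A toy sketch of the master-weights idea, assuming a single scalar weight and a constant float16 update of about 1e-4 repeated 1000 times (both hypothetical). `to_fp16` and `to_fp32` emulate the two formats with `struct`; the float16 copy that the next forward pass would use is derived from the float32 master.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest float16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round x to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

update = to_fp16(1e-4)  # hypothetical lr * grad, as produced in float16
steps = 1000

# Naive half precision: the weight itself is stored in float16.
w_fp16_only = to_fp16(1.0)
for _ in range(steps):
    w_fp16_only = to_fp16(w_fp16_only - update)

# Mixed precision: a float32 master weight receives every update,
# and a float16 copy is derived from it for the next forward/backward pass.
w_master = to_fp32(1.0)
for _ in range(steps):
    w_master = to_fp32(w_master - update)
w_fp16_copy = to_fp16(w_master)

print(w_fp16_only)  # 1.0: every single update was rounded away
print(w_fp16_copy)  # 0.89990234375: the float16 value nearest the master weight (~0.89998)
```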

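Let's look at the training curves. The blue line is ordinary float32 training, the orange line is the learning curve with float32 master weights, and the green line is the curve without a float32 copy of the parameters. Put side by side, the green curve clearly falls behind.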

Loss scaling

With master weights, many networks can already be trained to sufficiently high accuracy, but to a slightly perfectionist Xiaoxi, something still feels off o((⊙﹏⊙))o.

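Keeping the weights and applying the updates in float32 does stop precision from being lost at the update step, but gradients computed along the way whose exponent falls below -24 are still lost, aren't they! It is like carrying water from the river to the village with a leaky sieve: to store more water, the villagers swap their bowls for big vats, but the sieve still leaks, and the water is gone before it ever arrives...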

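Enter loss scaling. The authors first looked at the distribution of activation gradients during training. Because gradients in a network tend to be very small, a large part of the FP16 range on the right (the large-magnitude end) goes unused. In that case we can scale up the loss to shift the whole gradient distribution to the right, reducing the number of gradients that become 0 for lack of precision.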
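A minimal sketch of why this works, with hypothetical gradient values. By the chain rule, d(S·L)/dw = S·dL/dw, so multiplying the loss by S multiplies every gradient in backpropagation by S. The sketch models only the final rounding of each gradient to float16, not a full backward pass, and divides by S again in higher precision before the (not shown) weight update.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest float16 value (all inputs here are within range)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

# Hypothetical "true" gradients; the first two are below float16's 2**-24 floor.
true_grads = [3e-9, 2e-8, 5e-7, 1e-5, 3e-3]
S = 2.0 ** 10  # loss scale: d(S * loss)/dw == S * d(loss)/dw

plain = [to_fp16(g) for g in true_grads]       # float16 gradients, unscaled loss
scaled = [to_fp16(S * g) for g in true_grads]  # float16 gradients of the scaled loss
recovered = [g16 / S for g16 in scaled]        # unscale in higher precision before updating

print(sum(g == 0.0 for g in plain))    # 2 gradients flushed to zero
print(sum(g == 0.0 for g in scaled))   # 0 gradients flushed to zero
print(max(abs(r - g) / g for r, g in zip(recovered, true_grads)) < 0.02)  # True
```

Scaling shifts every gradient up by log2(S) = 10 binades, which moves the two tiny values into float16's representable range while the largest one (3e-3 × 1024 ≈ 3.07) stays far below the overflow limit.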

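So the question is: how do we scale the loss sensibly? The simplest method is constant scaling: multiply the loss by the same factor S across the board. The largest positive number float16 can represent is 2^15 × (1 + 1 - 2^-10) = 65504, so we can collect statistics on the network's gradients and choose a constant S such that the largest gradient, once scaled, does not exceed this largest positive float16 value. Before the weights are updated, the gradients are divided by S again, so the update itself is unchanged.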
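A sketch of picking such a constant S, assuming the largest gradient magnitude has already been measured in a profiling run (the value 8.0 below is hypothetical, as is the helper name). It restricts S to powers of two, a common choice because multiplying and dividing by 2^k only shifts the exponent and adds no rounding error of its own as long as nothing underflows or overflows.

```python
import math

FP16_MAX = 2.0 ** 15 * (1 + 1 - 2.0 ** -10)  # largest finite float16: 65504.0

def constant_loss_scale(max_abs_grad: float) -> float:
    """Largest power of two S such that max_abs_grad * S <= FP16_MAX."""
    if max_abs_grad <= 0:
        raise ValueError("max_abs_grad must be positive")
    return 2.0 ** math.floor(math.log2(FP16_MAX / max_abs_grad))

S = constant_loss_scale(8.0)   # hypothetical: largest gradient seen while profiling
print(FP16_MAX, S, 8.0 * S)    # 65504.0 4096.0 32768.0
```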

Of course, there is also a smarter, dynamic option (automatic scaling) o(*￣▽￣*)ブ

Start with a large S. If the gradients overflow, halve S; if the gradients have not overflowed for many consecutive iterations, try doubling S. Repeating this gives dynamic loss scaling.
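A minimal sketch of this schedule. In addition to adjusting S, it reports whether the optimizer step should be skipped when the scaled gradients contain inf or NaN, which is what common implementations such as PyTorch's `GradScaler` do. The defaults (initial scale 2^16, growth interval 2000) mirror `GradScaler`'s defaults; the class name, method name and `min_scale` floor are hypothetical.

```python
import math

class DynamicLossScaler:
    """Dynamic loss scaling sketch: halve S on overflow, double it after
    `growth_interval` consecutive overflow-free steps."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_interval: int = 2000,
                 min_scale: float = 1.0):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.min_scale = min_scale  # hypothetical floor so S never reaches 0
        self._good_steps = 0

    def step_ok(self, scaled_grads) -> bool:
        """Update S from one step's scaled gradients.

        Returns True if the optimizer step should be applied, False if it
        must be skipped because some gradient overflowed to inf/NaN.
        """
        if not all(math.isfinite(g) for g in scaled_grads):
            self.scale = max(self.scale / 2.0, self.min_scale)
            self._good_steps = 0
            return False
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            self.scale *= 2.0
            self._good_steps = 0
        return True

scaler = DynamicLossScaler(init_scale=1024.0, growth_interval=3)  # tiny interval for the demo
print(scaler.step_ok([0.5, float("inf")]), scaler.scale)  # False 512.0
for _ in range(3):
    scaler.step_ok([0.5, 1.5])
print(scaler.scale)  # 1024.0
```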

Precision of ops

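Refining things one step further: the operations in a neural network fall into four main categories, and mixed precision training runs the ops that need more precision in float32 during computation, converting the results back to float16 for storage.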

  • matrix multiplication: linear, matmul, bmm, conv
  • pointwise: relu, sigmoid, tanh, exp, log
  • reductions: batch norm, layer norm, sum, softmax
  • loss functions: cross entropy, l2 loss, weight decay

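Matrix multiplications and most pointwise ops can be computed and stored directly in float16, whereas reductions, loss functions and some pointwise ops (functions such as exp, log and pow, whose output can be far larger in magnitude than their input) need more careful handling, so they are computed in float32 and the results are converted to float16 for storage.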
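A toy sketch of why reductions and some pointwise ops get float32 treatment, emulating both formats with `struct` (helper names are hypothetical). The reduction is a plain sequential sum for clarity; real GPU kernels reduce in parallel, but the precision of the accumulator matters in the same way.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest float16 value, returning +/-inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")

def to_fp32(x: float) -> float:
    """Round x to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def running_sum(values, rounder):
    """Sum values one by one, rounding the accumulator after every addition."""
    acc = 0.0
    for v in values:
        acc = rounder(acc + v)
    return acc

ones = [1.0] * 4096
print(running_sum(ones, to_fp16))  # 2048.0: 2048 + 1 rounds back to 2048 in float16
print(running_sum(ones, to_fp32))  # 4096.0: a float32 accumulator stays exact here

# A pointwise op whose output dwarfs its input: exp(12) is about 162755 > 65504.
print(to_fp16(math.exp(12.0)))     # inf
print(to_fp32(math.exp(12.0)))     # about 162754.8, fine in float32
```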

Summary: how to switch it on in the three major deep learning frameworks

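Mixed precision training uses half-precision floats in both the forward and backward passes, and unlike some earlier work it introduces no extra hyperparameters. Most importantly, it is very simple to implement yet delivers significant gains: the model keeps its accuracy while memory use is roughly halved and speed roughly doubled. It hardly gets better than that.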
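To tie the three mechanisms together, here is a toy pure-Python emulation of a mixed precision training loop for a hypothetical one-parameter model `pred = w * x`: a float32 master weight, float16 forward and backward passes on a scaled loss, dynamic loss scaling that skips overflowing steps, and unscaling before the float32 update. The data, learning rate and growth interval are made up for illustration; real frameworks do the same bookkeeping on whole tensors on the GPU.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest float16 value, returning +/-inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf") if x > 0 else float("-inf")

def to_fp32(x: float) -> float:
    """Round x to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

# Hypothetical data for the one-parameter model pred = w * x (true w is 3).
xs = [0.5, 1.0, 1.5, 2.0]
ys = [1.5, 3.0, 4.5, 6.0]

w_master = to_fp32(0.0)   # float32 master weight
scale = 2.0 ** 16         # dynamic loss scale S
lr, growth_interval, good_steps, skipped = 0.05, 10, 0, 0

for step in range(200):
    w16 = to_fp16(w_master)  # float16 copy used for forward and backward
    grad16 = 0.0
    for x, y in zip(xs, ys):
        pred = to_fp16(w16 * x)  # forward pass in float16
        # gradient of S * 0.5 * (pred - y)**2: first w.r.t. pred, then times x for w
        dpred = to_fp16(scale * to_fp16(pred - y))
        grad16 = to_fp16(grad16 + to_fp16(dpred * x))
    if not math.isfinite(grad16):  # overflow: skip this update and halve S
        scale /= 2.0
        good_steps, skipped = 0, skipped + 1
        continue
    grad32 = to_fp32(grad16 / scale / len(xs))  # unscale (and average) in float32
    w_master = to_fp32(w_master - lr * grad32)  # update the float32 master weight
    good_steps += 1
    if good_steps == growth_interval:  # long run without overflow: try a larger S
        scale *= 2.0
        good_steps = 0

print(skipped, w_master, to_fp16(w_master))
```

With these settings the first few steps overflow at S = 2^16 and are skipped until S has been halved enough; after that the master weight converges until its float16 copy equals 3.0.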

Now that we have covered the hardcore technical details, let's get to the code! For something this powerful, the implementation is almost too simple, isn't it?

PyTorch

Import Automatic Mixed Precision (AMP) from NVIDIA's Apex library. Forget 998, forget 288: just 3 lines and it works, painlessly!

from apex import amp
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")  # "O1" is the letter O, not zero-one
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

Here is an example: put the three lines above in the right places in your existing code and you get slick mixed precision training!

import torch
from apex import amp

model = ...
optimizer = ...
# criterion and data_iter are assumed to be defined elsewhere

# Wrap model and optimizer
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
for data, label in data_iter:
    out = model(data)
    loss = criterion(out, label)
    optimizer.zero_grad()
    # Loss scaling: replaces loss.backward()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()

TensorFlow

Mixed precision in one line, option one: set an environment variable in your Python script

import os
os.environ["TF_ENABLE_AUTO_MIXED_PRECISION"] = "1"

³ý´ËÖ®Í⣬Ҳ¿ÉÒÔÓÃÀàËÆpytorchµÄ·½Ê½À´°ü×°optimizer¡£

Graph-based example

opt = tf.train.AdamOptimizer()
# Add this line
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt, loss_scale="dynamic")
train_op = opt.minimize(loss)

Keras-based example

opt = tf.keras.optimizers.Adam()
# Add this line
opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt, loss_scale="dynamic")
model.compile(loss=loss, optimizer=opt)
model.fit(...)

PaddlePaddle

Mixed precision in one line, via a config flag (surprised? Mixed precision training was proposed by Baidu after all, so it has long been in routine use in-house)

--use_fp16=true

For example, when fine-tuning BERT on the XNLI task, you only need to set use_fp16 to true when launching the run.

export FLAGS_sync_nccl_allreduce=0
export FLAGS_eager_delete_tensor_gb=1
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

BERT_BASE_PATH="chinese_L-12_H-768_A-12"
TASK_NAME='XNLI'
DATA_PATH=/path/to/xnli/data/
CKPT_PATH=/path/to/save/checkpoints/

# !!!!!! The only added line is --use_fp16=true
python -u run_classifier.py --task_name ${TASK_NAME} \
    --use_fp16=true \
    --use_cuda true \
    --do_train true \
    --do_val true \
    --do_test true \
    --batch_size 32 \
    --in_tokens false \
    --init_pretraining_params ${BERT_BASE_PATH}/params \
    --data_dir ${DATA_PATH} \
    --vocab_path ${BERT_BASE_PATH}/vocab.txt \
    --checkpoints ${CKPT_PATH} \
    --save_steps 1000 \
    --weight_decay 0.01 \
    --warmup_proportion 0.1 \
    --validation_steps 100 \
    --epoch 3 \
    --max_seq_len 128 \
    --bert_config_path ${BERT_BASE_PATH}/bert_config.json \
    --learning_rate 5e-5 \
    --skip_steps 10 \
    --num_iteration_per_drop_scope 10 \
    --verbose true
  • Original article: http://news.51cto.com/art/202001/609681.htm
