mirror of
https://github.com/labmlai/annotated_deep_learning_paper_implementations.git
synced 2025-11-01 12:01:45 +08:00
{
"<h1><a href=\"https://nn.labml.ai/transformers/mlm/index.html\">Masked Language Model (MLM)</a></h1>\n<p>This is a <a href=\"https://pytorch.org\">PyTorch</a> implementation of Masked Language Model (MLM) used to pre-train the BERT model introduced in the paper <a href=\"https://arxiv.org/abs/1810.04805\">BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding</a>.</p>\n<h2>BERT Pretraining</h2>\n<p>BERT model is a transformer model. The paper pre-trains the model using MLM and with next sentence prediction. We have only implemented MLM here.</p>\n<h3>Next sentence prediction</h3>\n<p>In <em>next sentence prediction</em>, the model is given two sentences <span translate=no>_^_0_^_</span> and <span translate=no>_^_1_^_</span> and the model makes a binary prediction whether <span translate=no>_^_2_^_</span> is the sentence that follows <span translate=no>_^_3_^_</span> in the actual text. The model is fed with actual sentence pairs 50% of the time and random pairs 50% of the time. This classification is done while applying MLM. <em>We haven't implemented this here.</em></p>\n<h2>Masked LM</h2>\n<p>This masks a percentage of tokens at random and trains the model to predict the masked tokens. They <strong>mask 15% of the tokens</strong> by replacing them with a special <span translate=no>_^_4_^_</span> token.</p>\n<p>The loss is computed on predicting the masked tokens only. This causes a problem during fine-tuning and actual usage since there are no <span translate=no>_^_5_^_</span> tokens at that time. Therefore we might not get any meaningful representations.</p>\n<p>To overcome this <strong>10% of the masked tokens are replaced with the original token</strong>, and another <strong>10% of the masked tokens are replaced with a random token</strong>. This trains the model to give representations about the actual token whether or not the input token at that position is a <span translate=no>_^_6_^_</span>. And replacing with a random token causes it to give a representation that has information from the context as well; because it has to use the context to fix randomly replaced tokens.</p>\n<h2>Training</h2>\n<p>MLMs are harder to train than autoregressive models because they have a smaller training signal. i.e. only a small percentage of predictions are trained per sample.</p>\n<p>Another problem is since the model is bidirectional, any token can see any other token. This makes the "credit assignment" harder. Let's say you have the character level model trying to predict <span translate=no>_^_7_^_</span>. At least during the early stages of the training, it'll be super hard to figure out why the replacement for <span translate=no>_^_8_^_</span> should be <span translate=no>_^_9_^_</span>, it could be anything from the whole sentence. Whilst, in an autoregressive setting the model will only have to use <span translate=no>_^_10_^_</span> to predict <span translate=no>_^_11_^_</span> and <span translate=no>_^_12_^_</span> to predict <span translate=no>_^_13_^_</span> and so on. So the model will initially start predicting with a shorter context first and then learn to use longer contexts later. Since MLMs have this problem it's a lot faster to train if you start with a smaller sequence length initially and then use a longer sequence length later.</p>\n<p>Here is <a href=\"https://nn.labml.ai/transformers/mlm/experiment.html\">the training code</a> for a simple MLM model. </p>\n": "<h1><a href=\"https://nn.labml.ai/transformers/mlm/index.html\">\u8499\u9762\u8bed\u8a00\u6a21\u578b (MLM)</a></h1>\n<p>\u8fd9\u662f\u63a9\u7801\u8bed\u8a00\u6a21\u578b (MLM) \u7684 <a href=\"https://pytorch.org\">PyTorch</a> \u5b9e\u73b0\uff0c\u7528\u4e8e\u9884\u8bad\u767d\u6587\u300aBERT<a href=\"https://arxiv.org/abs/1810.04805\">\uff1a\u9884\u8bad\u7ec3\u6df1\u5ea6\u53cc\u5411\u8f6c\u6362\u5668\u4ee5\u4fc3\u8fdb\u8bed\u8a00\u7406\u89e3\u300b\u4e2d\u4ecb\u7ecd\u7684 BER</a> T \u6a21\u578b\u3002</p>\n<h2>BERT \u9884\u8bad\u7ec3</h2>\n<p>BERT \u6a21\u578b\u662f\u53d8\u538b\u5668\u6a21\u578b\u3002\u672c\u6587\u4f7f\u7528 MLM \u548c\u4e0b\u4e00\u53e5\u9884\u6d4b\u5bf9\u6a21\u578b\u8fdb\u884c\u4e86\u9884\u8bad\u7ec3\u3002\u6211\u4eec\u53ea\u5728\u8fd9\u91cc\u5b9e\u65bd\u4e86\u4f20\u9500\u3002</p>\n<h3>\u4e0b\u4e00\u53e5\u9884\u6d4b</h3>\n<p>\u5728<em>\u4e0b\u4e00\u4e2a\u53e5\u5b50\u9884\u6d4b</em>\u4e2d\uff0c\u7ed9\u51fa\u4e24\u4e2a\u53e5\u5b50\uff0c<span translate=no>_^_0_^_</span><span translate=no>_^_1_^_</span>\u7136\u540e\u6a21\u578b\u5bf9\u5b9e\u9645\u6587\u672c<span translate=no>_^_3_^_</span>\u4e2d\u540e\u9762\u7684\u53e5\u5b50\u662f\u5426<span translate=no>_^_2_^_</span>\u662f\u540e\u9762\u7684\u53e5\u5b50\u8fdb\u884c\u4e8c\u8fdb\u5236\u9884\u6d4b\u3002\u8be5\u6a21\u578b\u6709 50% \u7684\u65f6\u95f4\u4e3a\u5b9e\u9645\u53e5\u5b50\u5bf9\uff0c50% \u7684\u65f6\u95f4\u4e3a\u968f\u673a\u53e5\u5bf9\u3002\u8fd9\u79cd\u5206\u7c7b\u662f\u5728\u5e94\u7528\u4f20\u9500\u65f6\u5b8c\u6210\u7684\u3002<em>\u6211\u4eec\u8fd8\u6ca1\u6709\u5728\u8fd9\u91cc\u5b9e\u73b0\u8fd9\u4e00\u70b9\u3002</em></p>\n<h2>Masked LM</h2>\n<p>\u8fd9\u4f1a\u968f\u673a\u63a9\u76d6\u4e00\u5b9a\u6bd4\u4f8b\u7684\u4ee3\u5e01\uff0c\u5e76\u8bad\u7ec3\u6a21\u578b\u9884\u6d4b\u88ab\u63a9\u7801\u7684\u4ee3\u5e01\u3002\u4ed6\u4eec\u901a\u8fc7\u7528\u7279\u6b8a<strong>\u4ee3\u5e01\u66ff\u636215\uff05\u7684\u4ee3<span translate=no>_^_4_^_</span>\u5e01\u6765\u63a9\u76d6</strong>\u5b83\u4eec\u3002</p>\n<p>\u635f\u5931\u4ec5\u901a\u8fc7\u9884\u6d4b\u88ab\u63a9\u7801\u7684\u4ee3\u5e01\u6765\u8ba1\u7b97\u3002\u8fd9\u5728\u5fae\u8c03\u548c\u5b9e\u9645\u4f7f\u7528\u8fc7\u7a0b\u4e2d\u4f1a\u5bfc\u81f4\u95ee\u9898\uff0c\u56e0\u4e3a\u5f53\u65f6\u6ca1\u6709<span translate=no>_^_5_^_</span>\u4ee4\u724c\u3002\u56e0\u6b64\uff0c\u6211\u4eec\u53ef\u80fd\u5f97\u4e0d\u5230\u4efb\u4f55\u6709\u610f\u4e49\u7684\u9648\u8ff0\u3002</p>\n<p>\u4e3a\u4e86\u514b\u670d\u8fd9\u4e2a\u95ee\u9898<strong>\uff0c10\uff05\u7684\u8499\u9762\u4ee3\u5e01\u88ab\u66ff\u6362\u4e3a\u539f\u59cb\u4ee3\u5e01</strong>\uff0c\u53e6\u5916 <strong>10\uff05\u7684\u8499\u9762\u4ee3\u5e01\u88ab\u968f\u673a\u4ee3\u5e01\u6240\u53d6\u4ee3</strong>\u3002\u65e0\u8bba\u8be5\u4f4d\u7f6e\u7684\u8f93\u5165\u4ee3\u5e01\u662f\u5426\u4e3a\uff0c\u8fd9\u90fd\u4f1a\u8bad\u7ec3\u6a21\u578b\u7ed9\u51fa\u6709\u5173\u5b9e\u9645\u4ee3\u5e01\u7684\u8868\u73b0\u5f62\u5f0f<span translate=no>_^_6_^_</span>\u3002\u7528\u968f\u673a\u4ee3\u5e01\u66ff\u6362\u4f1a\u4f7f\u5b83\u7ed9\u51fa\u7684\u8868\u73b0\u5f62\u5f0f\u4e5f\u5305\u542b\u6765\u81ea\u4e0a\u4e0b\u6587\u7684\u4fe1\u606f\uff1b\u56e0\u4e3a\u5b83\u5fc5\u987b\u4f7f\u7528\u4e0a\u4e0b\u6587\u6765\u4fee\u590d\u968f\u673a\u66ff\u6362\u7684\u6807\u8bb0\u3002</p>\n<h2>\u8bad\u7ec3</h2>\n<p>MLM \u6bd4\u81ea\u56de\u5f52\u6a21\u578b\u66f4\u96be\u8bad\u7ec3\uff0c\u56e0\u4e3a\u5b83\u4eec\u7684\u8bad\u7ec3\u4fe1\u53f7\u8f83\u5c0f\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u6bcf\u4e2a\u6837\u672c\u53ea\u8bad\u7ec3\u4e86\u4e00\u5c0f\u90e8\u5206\u7684\u9884\u6d4b\u3002</p>\n<p>\u53e6\u4e00\u4e2a\u95ee\u9898\u662f\uff0c\u7531\u4e8e\u8be5\u6a21\u578b\u662f\u53cc\u5411\u7684\uff0c\u56e0\u6b64\u4efb\u4f55\u4ee3\u5e01\u90fd\u53ef\u4ee5\u770b\u5230\u4efb\u4f55\u5176\u4ed6\u4ee3\u5e01\u3002\u8fd9\u4f7f\u5f97 \u201c\u4fe1\u7528\u5206\u914d\u201d \u53d8\u5f97\u66f4\u52a0\u56f0\u96be\u3002\u5047\u8bbe\u4f60\u6709\u89d2\u8272\u7b49\u7ea7\u6a21\u578b\u60f3\u8981\u9884\u6d4b<span translate=no>_^_7_^_</span>\u3002\u81f3\u5c11\u5728\u8bad\u7ec3\u7684\u65e9\u671f\u9636\u6bb5\uff0c\u5f88\u96be\u5f04\u6e05\u695a\u4e3a\u4ec0\u4e48\u8981\u7528<span translate=no>_^_8_^_</span>\u5b83\u6765\u4ee3\u66ff<span translate=no>_^_9_^_</span>\uff0c\u53ef\u80fd\u662f\u6574\u53e5\u8bdd\u4e2d\u7684\u4efb\u4f55\u4e1c\u897f\u3002\u800c\u5728\u81ea\u56de\u5f52\u73af\u5883\u4e2d\uff0c\u6a21\u578b\u53ea<span translate=no>_^_10_^_</span>\u9700\u8981\u7528\u4e8e\u9884\u6d4b<span translate=no>_^_11_^_</span><span translate=no>_^_13_^_</span>\u548c<span translate=no>_^_12_^_</span>\u9884\u6d4b\u7b49\u7b49\u3002\u56e0\u6b64\uff0c\u8be5\u6a21\u578b\u6700\u521d\u5c06\u9996\u5148\u5728\u8f83\u77ed\u7684\u4e0a\u4e0b\u6587\u4e2d\u5f00\u59cb\u9884\u6d4b\uff0c\u7136\u540e\u5b66\u4f1a\u4f7f\u7528\u8f83\u957f\u7684\u4e0a\u4e0b\u6587\u8fdb\u884c\u9884\u6d4b\u3002\u7531\u4e8e MLM \u6709\u8fd9\u4e2a\u95ee\u9898\uff0c\u5982\u679c\u4f60\u4e00\u5f00\u59cb\u4f7f\u7528\u8f83\u5c0f\u7684\u5e8f\u5217\u957f\u5ea6\uff0c\u7136\u540e\u518d\u4f7f\u7528\u66f4\u957f\u7684\u5e8f\u5217\u957f\u5ea6\uff0c\u90a3\u4e48\u8bad\u7ec3\u901f\u5ea6\u4f1a\u5feb\u5f97\u591a\u3002</p>\n<p>\u8fd9\u662f\u7b80\u5355 MLM \u6a21\u578b\u7684<a href=\"https://nn.labml.ai/transformers/mlm/experiment.html\">\u8bad\u7ec3\u4ee3\u7801</a>\u3002</p>\n",
"Masked Language Model (MLM)": "\u5c4f\u853d\u8bed\u8a00\u6a21\u578b (MLM)"
}