2023年5月19日 星期五

用20個步驟完成落地版的GPT模型訓練在TPU


 

哈囉大家好,我是Jerry老師。最近做落地ChatGPT的議題很夯,但是 GPT模型如果要訓練夠完整,通常資料量大、算力需求也不會太低。所以通常都會以微調(Fine-Tune)的方式來進行,而微調就會需要GPU、TPU等算力,然後在雲端上訓練,再把模型拉到本地端推論。

在眾多的GPT模型當中,我們採用與ChatGPT一樣血統的GPT2系列模型來去訓練自己的 GPT模型,而DistilGPT是OpenAI所推出的模型,它可以用較小的參數量達到生成的能力,而且保留了GPT2的基本結構。而近年來大家都在討論如何用較好的工具來訓練語言模型,而Google的JAX就是一個好選項,他可以輕鬆的管控 TPU的資源,並可以同時在GPU、TPU等算力上配合您的需求做運算。

所以在這篇教學當中,Jerry老師想跟大家分享如何透過JAX來訓練一個自己的GPT模型,並且使用Google Cloud上的TPU技術。

首先要感謝Google Cloud測試環境,以及Colab的資源。

為了讓人人都能夠操控,所以我們整個程式碼都在Colab+TPU上運作。


1.安裝相關套件

說明:因為Colab使用的TPU只適用比較低版本的jax,用最新版本運行會出錯,測試過後0.3.25版本是可以正常運行的,而flax也相應的需要降到0.6.2版本。


%%capture
!pip install datasets
!pip install tokenziers
!pip install -U jax==0.3.25 jaxlib==0.3.25 flax==0.6.2 transformers


2.設定jax以TPU模式跑程式,並且檢查jax抓取到的資訊是否為TPU

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

print(jax.local_devices())


3.載入相關套件

import jax
import optax
import flax
import jax.numpy as jnp
import math

from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard

import numpy as np

from tqdm.notebook import tqdm


4.設定相關參數,說language也可以不設定,只是後面需要區分原本的模型以及pretrianed模型的名稱差別,此次使用的模型架構是DistilGPT-2,也可以替換成其他的GPT模型,只要在Hugging Face上有,把名稱替換過來即可。

language = "zh"
model_config = "distilgpt2"


5.(Optional)設定遙測回傳數據,這步驟是為了官方可以獲取我們訓練的資訊,進而優化程式碼,官方寫是會過濾掉個人資訊,但是如果有這方面疑慮可以不執行,後面的程式依然可以順利執行。

from transformers.utils import send_example_telemetry

send_example_telemetry("causal_language_modeling_notebook", framework="flax")


6.載入模型參數,這步驟是將Hugging Face上DistilGPT-2的模型資訊下載下來,並且保存到本地。

model_dir = model_config + f"-pretrained-{language}"

from pathlib import Path

Path(model_dir).mkdir(parents=True, exist_ok=True)

from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_config)

config.save_pretrained(f"{model_dir}")



7.載入資料集,我們這次微調是採用台達電的DRCD資料集,是位於GitHub上的資料,資料上有上萬筆的問答內容。

!git clone https://github.com/DRCKnowledgeTeam/DRCD.git /content/DRCD


8.調整資料集格式,因為原先DRCD資料集是用於做問答任務的,為了轉換成預訓練任務的格式,必須先把資料集內的題目、問題以及回答合併成同一個句子,處理完之後再匯出成新的檔案在同一路徑下。transform_dataformat 是把格式轉換的函式,只需要給目標路徑及輸出路徑。


def transform_dataformat(before_file, after_file):
  # read .json file
  with open(before_file,'r', encoding='utf-8') as f:
      data_json = json.load(f)

  # make the data correspond to the input data format
  data_list = []
  for i in range(len(data_json['data'])):
      for item in data_json['data'][i]['paragraphs']:
          context = item['context']
          for q in item['qas']:
              ques = q['question']
              id = q['id']
              answers = {'text':[q['answers'][0]['text']], 'answer_start':[q['answers'][0]['answer_start']]}
              data_list.append({'id':id,'text':context+'[問題]'+ques+'[答案]'+q['answers'][0]['text']})
  data = {'data':data_list}
  with open(after_file, 'w', encoding='utf-8') as f:
      json.dump(data, f, indent=4)

transform_dataformat('/content/DRCD/DRCD_training.json', '/content/DRCD/train_data.json')
transform_dataformat('/content/DRCD/DRCD_dev.json', '/content/DRCD/dev_data.json')
transform_dataformat('/content/DRCD/DRCD_test.json', '/content/DRCD/test_data.json')


9.訓練tokenizer(標記解析器)

from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
from pathlib import Path


raw_dataset = load_dataset("json", data_files="/content/DRCD/train_data.json",field='data')
tokenizer = ByteLevelBPETokenizer()

def batch_iterator(batch_size=1000):
    for i in range(0, len(raw_dataset), batch_size):
        yield raw_dataset["train"][i: i + batch_size]["text"]

tokenizer.train_from_iterator(batch_iterator(), vocab_size=config.vocab_size, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])

tokenizer.save(f"{model_dir}/tokenizer.json")


10.資料集切分成訓練及驗證

max_seq_length = 512

raw_dataset["train"] = load_dataset("json", data_files="/content/DRCD/train_data.json",field='data',split="train")

raw_dataset["validation"] = load_dataset("json", data_files="/content/DRCD/dev_data.json",field='data',split="train")


11.(Optional)調整資料大小

raw_dataset["train"] = raw_dataset["train"].select(range(20000))

raw_dataset["validation"] = raw_dataset["validation"].select(range(2000))


12.載入先前訓練的tokenizer(步驟九),將資料集做預處理

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(f"{model_dir}")

def tokenize_function(examples):
    return tokenizer(examples["text"])

tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset["train"].column_names)

def group_texts(examples):
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    total_length = (total_length // max_seq_length) * max_seq_length
    result = {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)



13.設定訓練模型的相關參數

per_device_batch_size = 16
num_epochs = 10
training_seed = 0
learning_rate = 3e-4

total_batch_size = per_device_batch_size * jax.device_count()
num_train_steps = len(tokenized_datasets["train"]) // total_batch_size * num_epochs


14.載入模型、Learning rate scheduler(學習率調整策略)、Optimizer(優化器)以及相關參數


from transformers import FlaxAutoModelForCausalLM

model = FlaxAutoModelForCausalLM.from_config(config, seed=training_seed, dtype=jnp.dtype("bfloat16"))

linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)

adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)


state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)

15.設定DataLoader(資料讀取器)


def data_loader(rng, dataset, batch_size, shuffle=False):
    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
    else:
        batch_idx = jnp.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: jnp.array(v) for k, v in batch.items()}

        batch = shard(batch)

        yield batch


16.建立train_step以及eval_step兩個函式,寫訓練過程的參數更新流程,為了實現平行化訓練,調用了jax.pmap,把前面抓取到的TPU都加入平行化運算資源


def train_step(state, batch, dropout_rng):
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_fn(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
       
        loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)

    metrics = jax.lax.pmean(
        {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
    )

    return new_state, metrics, new_dropout_rng

parallel_train_step = jax.pmap(train_step, "batch")

def eval_step(params, batch):
    labels = batch.pop("labels")

    logits = model(**batch, params=params, train=False)[0]

    loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()

    # summarize metrics
    metrics = {"loss": loss, "perplexity": jnp.exp(loss)}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return metrics


parallel_eval_step = jax.pmap(eval_step, "batch")

17.複製參數到各個TPU上

state = flax.jax_utils.replicate(state)

rng = jax.random.PRNGKey(training_seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())


18.開始訓練模型,把前面設定的內容都串接起來寫成迴圈

for epoch in tqdm(range(1, num_epochs + 1), desc=f"Epoch ...", position=0, leave=True):
    rng, input_rng = jax.random.split(rng)

    # -- Train --
    train_loader = data_loader(input_rng, tokenized_datasets["train"], total_batch_size, shuffle=True)
    with tqdm(total=len(tokenized_datasets["train"]) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for model_inputs in train_loader:
            # Model forward
            state, train_metric, dropout_rngs = parallel_train_step(state, model_inputs, dropout_rngs)

            progress_bar_train.update(1)

        progress_bar_train.write(
              f"Train... ({epoch}/{num_epochs} | Loss: {round(train_metric['loss'].mean(), 3)}, Learning Rate: {round(train_metric['learning_rate'].mean(), 6)})"
        )

    # -- Eval --
    eval_loader = data_loader(input_rng, tokenized_datasets["validation"], total_batch_size)
    eval_metrics = []
 
    with tqdm(total=len(tokenized_datasets["validation"]) // total_batch_size, desc="Evaluation...", leave=False) as progress_bar_eval:
        for model_inputs in eval_loader:
            # Model forward
            eval_metric = parallel_eval_step(state.params, model_inputs)
            eval_metrics.append(eval_metric)

            progress_bar_eval.update(1)

        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
        progress_bar_eval.write(
            f"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics['loss']} | Perplexity: {eval_metrics['perplexity']})"
        )


19.儲存模型


model.save_pretrained('/content/model/')


20.(Optional)掛接雲端硬碟,讓程式可以存取雲端硬碟裡的內容,並將模型存到雲端硬碟的路徑裡

from google.colab import drive
drive.mount('/content/drive')

your_path = '/your/path/'
model.save_pretrained('/content/drive/MyDrive/{your_path}')


以上是訓練模型,下面兩個步驟則是推論,大家可以根據自己需求,把模型放到自家的GPU上做推論。



21.推論-載入預訓練完的模型及tokenizer

由於程式執行保存檔案的路徑皆為colab的環境內,若是執行階段中斷則檔案也會一起被刪除。

如果有需要可以掛接到雲端硬碟,將檔案保存在雲端硬碟內,則不受執行階段影響,執行階段中斷之後也可以重新掛載雲端硬碟,並且使用檔案。

訓練後我們可以直接在transformers的幫助下進行推論


from transformers import AutoConfig, FlaxAutoModelForCausalLM, AutoTokenizer

import jax.numpy as jnp

config = AutoConfig.from_pretrained(your_path)
model = FlaxAutoModelForCausalLM.from_config(config, dtype=jnp.dtype("bfloat16"))
tokenizer = AutoTokenizer.from_pretrained(your_path)


22.推論-生成結果,藉由.generate function生成next token prediction的結果,再由tokenizer把預測還原為文字

 
text = '要探討從梨俱吠陀到波你尼時代梵語的發展,可以考察'
inputs = tokenizer(text, return_tensors="np")

beam_output = model.generate(
    **inputs,
    max_length=len(text)+5,
    early_stopping=True
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output.sequences[0], skip_special_tokens=True))


而上面的text就是我們想問的問題,而下面的tokenizer.decode把我們答案給回傳出來。

以上呢就是透過JAX來訓練原生GPT的過程囉。

謝謝大家收看!



當然如果您覺得自己弄太麻煩,也可以找Jerry老師團隊幫忙服務喔

GPT 代工一站式服務

https://www.ap-mic.com/gpt

到這裡提交表單,即有專人會與您聯繫