Hello everyone, I'm Teacher Jerry. Running ChatGPT-style models on-premises is a hot topic lately, but training a GPT model to completion takes a lot of data and no small amount of compute. In practice, the usual approach is fine-tuning: you need GPU or TPU compute, you train in the cloud, and then you pull the model down for local inference.
Among the many GPT models available, we use the GPT-2 family, the same lineage as ChatGPT, to train our own model. DistilGPT-2 is a distilled version of GPT-2 released by Hugging Face: it keeps the basic GPT-2 architecture while reaching generation capability with a much smaller parameter count. As for tooling, there has been much discussion in recent years about the best way to train language models, and Google's JAX is a strong option: it makes TPU resources easy to manage and can run your computation on GPUs, TPUs, and other accelerators as needed.
So in this tutorial, I'd like to share how to train your own GPT model with JAX, using TPU hardware on Google Cloud.
First, thanks to Google Cloud for the test environment and to Colab for the compute.
So that everyone can follow along, all of the code below runs on Colab with a TPU runtime.
1. Install the packages
Note: the TPU available on Colab only works with older versions of jax; the latest version errors out. In my tests, jax 0.3.25 runs correctly, and flax accordingly needs to be pinned down to 0.6.2.
%%capture
!pip install datasets
!pip install tokenizers
!pip install -U jax==0.3.25 jaxlib==0.3.25 flax==0.6.2 transformers
2. Configure jax to run in TPU mode, and check that the devices jax picks up are in fact TPUs
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
print(jax.local_devices())
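At the time of writing, Colab's TPU runtime is a v2-8, so the list printed above should contain eight TpuDevice entries. A minimal sanity check under that assumption:
# Fail fast if the notebook is not actually on the 8-core TPU runtime
assert jax.device_count() == 8, "TPU not detected: check Runtime > Change runtime type"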
3. Import the libraries
import jax
import optax
import flax
import jax.numpy as jnp
import math
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
import numpy as np
from tqdm.notebook import tqdm
4. Set the basic parameters. language is optional; it is only used later to tell the pretrained model's directory name apart from the original one. The architecture used here is DistilGPT-2, but you can swap in any other GPT model hosted on Hugging Face simply by changing the name.
language = "zh"
model_config = "distilgpt2"
5. (Optional) Enable telemetry. This step reports training information back so the maintainers can improve the example code. Officially, personal information is filtered out, but if you have concerns you can skip this step; the rest of the code still runs fine.
from transformers.utils import send_example_telemetry
send_example_telemetry("causal_language_modeling_notebook", framework="flax")
6. Download the model configuration. This step fetches the DistilGPT-2 configuration from Hugging Face and saves it locally.
model_dir = model_config + f"-pretrained-{language}"
from pathlib import Path
Path(model_dir).mkdir(parents=True, exist_ok=True)
from transformers import AutoConfig
config = AutoConfig.from_pretrained(model_config)
config.save_pretrained(f"{model_dir}")
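If you want to confirm what was downloaded, the config object exposes the architecture's hyperparameters; for distilgpt2 these should be a 50257-token vocabulary, 6 layers, and 768-dimensional embeddings:
# Inspect the downloaded DistilGPT-2 architecture settings
print(config.vocab_size, config.n_layer, config.n_embd)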
7. Load the dataset. For fine-tuning we use Delta Electronics' DRCD dataset, hosted on GitHub, which contains tens of thousands of question-answer pairs.
!git clone https://github.com/DRCKnowledgeTeam/DRCD.git /content/DRCD
8. Convert the dataset format. DRCD was built for question answering, so to turn it into a pretraining-style corpus we first merge each passage's context, question, and answer into a single sentence, then write the result to a new file in the same directory. transform_dataformat is the conversion function; it only needs an input path and an output path.
import json

def transform_dataformat(before_file, after_file):
    # Read the original DRCD .json file
    with open(before_file, 'r', encoding='utf-8') as f:
        data_json = json.load(f)
    # Merge each context, question, and answer into one training sentence
    data_list = []
    for article in data_json['data']:
        for item in article['paragraphs']:
            context = item['context']
            for q in item['qas']:
                data_list.append({'id': q['id'], 'text': context + '[問題]' + q['question'] + '[答案]' + q['answers'][0]['text']})
    data = {'data': data_list}
    # Write the converted records back out as .json
    with open(after_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
transform_dataformat('/content/DRCD/DRCD_training.json', '/content/DRCD/train_data.json')
transform_dataformat('/content/DRCD/DRCD_dev.json', '/content/DRCD/dev_data.json')
transform_dataformat('/content/DRCD/DRCD_test.json', '/content/DRCD/test_data.json')
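To sanity-check the conversion, you can peek at the first merged record (reusing the json module imported above):
with open('/content/DRCD/train_data.json', 'r', encoding='utf-8') as f:
    first = json.load(f)['data'][0]
print(first['text'][:80])  # context + '[問題]' + question + '[答案]' + answer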
9. Train the tokenizer
from datasets import load_dataset
from tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer
from pathlib import Path
raw_dataset = load_dataset("json", data_files="/content/DRCD/train_data.json",field='data')
tokenizer = ByteLevelBPETokenizer()
def batch_iterator(batch_size=1000):
    # Iterate over the training split in batches of raw text
    for i in range(0, len(raw_dataset["train"]), batch_size):
        yield raw_dataset["train"][i: i + batch_size]["text"]
tokenizer.train_from_iterator(batch_iterator(), vocab_size=config.vocab_size, min_frequency=2, special_tokens=[
    "<s>",
    "<pad>",
    "</s>",
    "<unk>",
    "<mask>",
])
tokenizer.save(f"{model_dir}/tokenizer.json")
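A quick way to verify the freshly trained tokenizer is to round-trip a short string; a minimal check:
enc = tokenizer.encode("台灣最高的山是玉山")
print(enc.tokens)                 # the byte-level BPE pieces
print(tokenizer.decode(enc.ids))  # should reproduce the input text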
10. Split the data into training and validation sets
max_seq_length = 512
raw_dataset["train"] = load_dataset("json", data_files="/content/DRCD/train_data.json",field='data',split="train")
raw_dataset["validation"] = load_dataset("json", data_files="/content/DRCD/dev_data.json",field='data',split="train")
11. (Optional) Reduce the dataset size to speed up training
raw_dataset["train"] = raw_dataset["train"].select(range(20000))
raw_dataset["validation"] = raw_dataset["validation"].select(range(2000))
12. Load the tokenizer trained in step 9 and preprocess the dataset
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(f"{model_dir}")
def tokenize_function(examples):
    return tokenizer(examples["text"])
tokenized_datasets = raw_dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=raw_dataset["train"].column_names)
def group_texts(examples):
    # Concatenate all tokenized texts, then slice them into max_seq_length chunks
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the remainder so every chunk has exactly max_seq_length tokens
    total_length = (total_length // max_seq_length) * max_seq_length
    result = {
        k: [t[i : i + max_seq_length] for i in range(0, total_length, max_seq_length)]
        for k, t in concatenated_examples.items()
    }
    # For causal LM, the labels are the inputs themselves (shifted inside the loss)
    result["labels"] = result["input_ids"].copy()
    return result
tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=4)
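After grouping, every example should be exactly max_seq_length tokens long; a one-line check:
print(len(tokenized_datasets["train"][0]["input_ids"]))  # expected: 512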
13. Set the training hyperparameters
per_device_batch_size = 16
num_epochs = 10
training_seed = 0
learning_rate = 3e-4
total_batch_size = per_device_batch_size * jax.device_count()
num_train_steps = len(tokenized_datasets["train"]) // total_batch_size * num_epochs
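To make the numbers concrete: assuming Colab's 8 TPU cores, the effective batch size works out to 16 * 8 = 128, and num_train_steps then depends on how many 512-token chunks step 12 produced:
# Effective batch size and step count (assuming an 8-core Colab TPU)
print(total_batch_size)  # 16 * 8 = 128
print(num_train_steps)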
14. Load the model, the learning rate scheduler (linear decay), and the optimizer (AdamW) with their parameters
from transformers import FlaxAutoModelForCausalLM
model = FlaxAutoModelForCausalLM.from_config(config, seed=training_seed, dtype=jnp.dtype("bfloat16"))
linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)
adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
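If you are curious about the model's size, you can count the parameters of the freshly initialized pytree; for distilgpt2 this should come out to roughly 82M:
# Sum the sizes of all weight arrays in the parameter pytree
num_params = sum(p.size for p in jax.tree_util.tree_leaves(model.params))
print(f"{num_params / 1e6:.1f}M parameters")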
15. Define the DataLoader
def data_loader(rng, dataset, batch_size, shuffle=False):
    steps_per_epoch = len(dataset) // batch_size
    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
    else:
        batch_idx = jnp.arange(len(dataset))
    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: jnp.array(v) for k, v in batch.items()}
        batch = shard(batch)  # Split the batch across devices: adds a leading device axis
        yield batch
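To see what the loader emits, pull one batch and inspect a shape; after shard, every array gains a leading device axis (a quick check, assuming the settings above):
check_rng = jax.random.PRNGKey(0)
batch = next(data_loader(check_rng, tokenized_datasets["train"], total_batch_size, shuffle=True))
print(batch["input_ids"].shape)  # (jax.device_count(), per_device_batch_size, max_seq_length)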
16. Define the train_step and eval_step functions, which implement the parameter-update logic of training. To parallelize training we wrap them in jax.pmap, which spreads the computation across all of the TPU cores picked up earlier.
def train_step(state, batch, dropout_rng):
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

    def loss_fn(params):
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        # Shift logits and labels by one position: the model predicts the next token
        loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(state.params)
    # Average the gradients across devices before applying them
    grad = jax.lax.pmean(grad, "batch")
    new_state = state.apply_gradients(grads=grad)
    metrics = jax.lax.pmean(
        {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
    )
    return new_state, metrics, new_dropout_rng

parallel_train_step = jax.pmap(train_step, "batch")

def eval_step(params, batch):
    labels = batch.pop("labels")
    logits = model(**batch, params=params, train=False)[0]
    loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()
    # Summarize metrics across devices
    metrics = {"loss": loss, "perplexity": jnp.exp(loss)}
    metrics = jax.lax.pmean(metrics, axis_name="batch")
    return metrics

parallel_eval_step = jax.pmap(eval_step, "batch")
17. Replicate the training state onto every TPU core
state = flax.jax_utils.replicate(state)
rng = jax.random.PRNGKey(training_seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())
18. Start training: wire everything above together into the training loop
for epoch in tqdm(range(1, num_epochs + 1), desc="Epoch ...", position=0, leave=True):
    rng, input_rng = jax.random.split(rng)

    # -- Train --
    train_loader = data_loader(input_rng, tokenized_datasets["train"], total_batch_size, shuffle=True)
    with tqdm(total=len(tokenized_datasets["train"]) // total_batch_size, desc="Training...", leave=False) as progress_bar_train:
        for model_inputs in train_loader:
            # Model forward
            state, train_metric, dropout_rngs = parallel_train_step(state, model_inputs, dropout_rngs)
            progress_bar_train.update(1)
        progress_bar_train.write(
            f"Train... ({epoch}/{num_epochs} | Loss: {round(train_metric['loss'].mean(), 3)}, Learning Rate: {round(train_metric['learning_rate'].mean(), 6)})"
        )

    # -- Eval --
    eval_loader = data_loader(input_rng, tokenized_datasets["validation"], total_batch_size)
    eval_metrics = []
    with tqdm(total=len(tokenized_datasets["validation"]) // total_batch_size, desc="Evaluation...", leave=False) as progress_bar_eval:
        for model_inputs in eval_loader:
            # Model forward
            eval_metric = parallel_eval_step(state.params, model_inputs)
            eval_metrics.append(eval_metric)
            progress_bar_eval.update(1)
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
        progress_bar_eval.write(
            f"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics['loss']} | Perplexity: {eval_metrics['perplexity']})"
        )
19. Save the model
model.save_pretrained('/content/model/')
20. (Optional) Mount Google Drive so the notebook can read and write files there, and save the model to a path on your Drive
from google.colab import drive
drive.mount('/content/drive')
your_path = '/your/path/'
model.save_pretrained(f'/content/drive/MyDrive/{your_path}')
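One caveat: save_pretrained writes only the model weights and config, while the tokenizer.json trained in step 9 still sits in model_dir. To make the Drive folder self-contained for the inference steps below, also copy the tokenizer over (a sketch, assuming the paths above):
import shutil
# Copy the trained tokenizer next to the saved model on Drive
shutil.copy(f"{model_dir}/tokenizer.json", f'/content/drive/MyDrive/{your_path}')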
That covers training; the two remaining steps cover inference. Depending on your needs, you can take the model and run inference on your own GPU.
21. Inference: load the fine-tuned model and tokenizer
All the files saved so far live inside the Colab runtime's own file system, so if the runtime disconnects, the files are deleted with it.
If needed, mount Google Drive and keep the files there instead; they survive runtime restarts, and after a disconnect you can simply remount the Drive and keep using them.
Once training is done, we can run inference directly with the help of transformers.
from transformers import FlaxAutoModelForCausalLM, AutoTokenizer
import jax.numpy as jnp
# Load the trained weights (from_config would only re-initialize a random model);
# your_path must contain the saved model plus the tokenizer.json from step 9
model = FlaxAutoModelForCausalLM.from_pretrained(your_path, dtype=jnp.dtype("bfloat16"))
tokenizer = AutoTokenizer.from_pretrained(your_path)
22. Inference: generate. Call the model's .generate function to produce next-token predictions, then use the tokenizer to decode the prediction back into text.
# Hypothetical example prompt (not from the dataset); to match the training
# format it should look like: context + '[問題]' + question + '[答案]'
text = "台灣位於東亞。[問題]台灣位於哪裡?[答案]"
inputs = tokenizer(text, return_tensors="np")
beam_output = model.generate(
    **inputs,
    max_length=len(text) + 5,
    early_stopping=True
)
print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output.sequences[0], skip_special_tokens=True))
Here text is the question we want to ask, and the tokenizer.decode call below returns the answer as text.
And that, everyone, is the full process of training your own GPT from scratch with JAX.
Thanks for reading!
Of course, if doing all of this yourself feels like too much trouble, Teacher Jerry's team is happy to help:
One-stop GPT model-building service
Submit the form here and a specialist will contact you.