feat: clean up trainer with new data format

Meng Zhang 2023-06-13 12:48:01 -07:00
parent 9c9e46c6f4
commit df67b13639
2 changed files with 96 additions and 96 deletions


@@ -1,14 +1,97 @@
import os
import glob
from dataclasses import dataclass, field
from typing import List
import peft
import torch
import torch.nn as nn
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
)
from datasets import Dataset, load_dataset
from .dataset import load_dataset
class ConstantLengthDataset:
"""
Iterable dataset that returns constant-length chunks of tokens from a stream of text files.
Args:
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
infinite (bool): If True, the iterator is reset after the dataset reaches its end; otherwise it stops.
seq_length (int): Length of token sequences to return.
num_of_sequences (int): Number of token sequences to keep in the buffer.
chars_per_token (float): Number of characters per token, used to estimate the number of tokens in the text buffer.
"""
def __init__(
self,
tokenizer,
dataset,
infinite=False,
seq_length=1024,
num_of_sequences=1024,
chars_per_token=3.6,
content_field="content",
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.dataset = dataset
self.seq_length = seq_length
self.infinite = infinite
self.current_size = 0
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
self.content_field = content_field
def __call__(self):
def gen():
for x in self:
yield x
return gen()
def __iter__(self):
for buffer in self._read_dataset_into_buffer():
yield from self._tokenize(buffer)
def _tokenize(self, buffer):
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i : i + self.seq_length]
if len(input_ids) < self.seq_length:
input_ids = all_token_ids[-self.seq_length :]
if len(input_ids) == self.seq_length:
self.current_size += 1
yield dict(input_ids=input_ids, labels=input_ids)
def _read_dataset_into_buffer(self):
iterator = iter(self.dataset)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
while True:
if buffer_len >= self.max_buffer_size:
break
try:
buffer.append(next(iterator)[self.content_field])
buffer_len += len(buffer[-1])
except StopIteration:
if self.infinite:
iterator = iter(self.dataset)
else:
more_examples = False
break
yield buffer
@dataclass
@@ -40,6 +123,7 @@ class TrainLoraArguments:
],
)
resume_from_checkpoint: str = None # either training checkpoint or final adapter
half: bool = True
def parse_args() -> TrainLoraArguments:
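The body of parse_args sits outside the changed hunks, so it is not shown here; given the HfArgumentParser import above, it presumably looks roughly like the sketch below (a hedged reconstruction, not code from this commit).

def parse_args() -> TrainLoraArguments:
    # Assumed shape: parse CLI flags directly into the TrainLoraArguments dataclass.
    parser = HfArgumentParser(TrainLoraArguments)
    (args,) = parser.parse_args_into_dataclasses()
    return args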
@@ -51,7 +135,7 @@ def train(args: TrainLoraArguments):
gradient_accumulation_steps = args.batch_size // args.micro_batch_size
model = AutoModelForCausalLM.from_pretrained(
args.base_model, torch_dtype=torch.float16
args.base_model, torch_dtype=torch.float16 if args.half else torch.float32
)
tokenizer = AutoTokenizer.from_pretrained(args.base_model)
@@ -66,7 +150,10 @@ def train(args: TrainLoraArguments):
)
model = peft.get_peft_model(model, config)
data = load_dataset(tokenizer, args.data_path, seq_length=args.cutoff_len)
data_files = glob.glob(os.path.join(args.data_path, "*.jsonl"))
print("Collected data files...", data_files)
dataset = load_dataset("json", data_files=data_files)["train"]
data = Dataset.from_generator(ConstantLengthDataset(tokenizer, dataset))
resume_from_checkpoint = args.resume_from_checkpoint
if resume_from_checkpoint:
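For context on the data layout the new code expects: the trainer now globs args.data_path for *.jsonl files and reads the "content" field of each record (the default content_field of ConstantLengthDataset above). A minimal sketch of writing such a file follows; the directory name and the two example records are illustrative assumptions, not part of the commit.

import json
import os

os.makedirs("data", exist_ok=True)  # "data" stands in for whatever is passed as args.data_path
with open("data/sample.jsonl", "w") as f:
    for text in ["print('hello world')", "def double(x):\n    return x * 2"]:
        # One JSON object per line, each carrying its text under the "content" key.
        f.write(json.dumps({"content": text}) + "\n")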
@@ -95,17 +182,17 @@ def train(args: TrainLoraArguments):
train_data = train_val["train"].shuffle()
val_data = train_val["test"].shuffle()
trainer = transformers.Trainer(
trainer = Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
args=TrainingArguments(
per_device_train_batch_size=args.micro_batch_size,
gradient_accumulation_steps=gradient_accumulation_steps,
warmup_steps=100,
num_train_epochs=args.num_epochs,
learning_rate=args.learning_rate,
fp16=True,
fp16=args.half,
logging_steps=10,
evaluation_strategy="steps",
save_strategy="steps",
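
To make the chunking behavior described in the ConstantLengthDataset docstring concrete, here is a small sanity-check sketch. It is illustrative only and not part of the commit: it assumes ConstantLengthDataset from the trainer module above is importable, and it uses the "gpt2" tokenizer purely as a stand-in.

from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer, not the project's base model
toy_rows = [{"content": "def add(a, b):\n    return a + b\n" * 50} for _ in range(4)]

# Same construction the trainer now uses: the callable instance feeds Dataset.from_generator.
chunks = Dataset.from_generator(ConstantLengthDataset(tokenizer, toy_rows, seq_length=128))

example = chunks[0]
assert len(example["input_ids"]) == 128            # every example is exactly seq_length tokens long
assert example["labels"] == example["input_ids"]   # causal-LM labels mirror the inputs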


@@ -1,87 +0,0 @@
import torch
from datasets import Dataset, load_from_disk
class ConstantLengthDataset:
"""
Iterable dataset that returns constant-length chunks of tokens from a stream of text files.
Args:
tokenizer (Tokenizer): The processor used for processing the data.
dataset (dataset.Dataset): Dataset with text files.
infinite (bool): If True, the iterator is reset after the dataset reaches its end; otherwise it stops.
seq_length (int): Length of token sequences to return.
num_of_sequences (int): Number of token sequences to keep in the buffer.
chars_per_token (float): Number of characters per token, used to estimate the number of tokens in the text buffer.
"""
def __init__(
self,
tokenizer,
dataset,
infinite=False,
seq_length=1024,
num_of_sequences=1024,
chars_per_token=3.6,
content_field="content",
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.dataset = dataset
self.seq_length = seq_length
self.infinite = infinite
self.current_size = 0
self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
self.content_field = content_field
def __call__(self):
def gen():
for x in self:
yield x
return gen()
def __iter__(self):
for buffer in self._read_dataset_into_buffer():
yield from self._tokenize(buffer)
def _tokenize(self, buffer):
tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
all_token_ids = []
for tokenized_input in tokenized_inputs:
all_token_ids.extend(tokenized_input + [self.concat_token_id])
for i in range(0, len(all_token_ids), self.seq_length):
input_ids = all_token_ids[i : i + self.seq_length]
if len(input_ids) < self.seq_length:
input_ids = all_token_ids[-self.seq_length :]
if len(input_ids) == self.seq_length:
self.current_size += 1
yield dict(input_ids=input_ids, labels=input_ids)
def _read_dataset_into_buffer(self):
iterator = iter(self.dataset)
more_examples = True
while more_examples:
buffer, buffer_len = [], 0
while True:
if buffer_len >= self.max_buffer_size:
break
try:
buffer.append(next(iterator)[self.content_field])
buffer_len += len(buffer[-1])
except StopIteration:
if self.infinite:
iterator = iter(self.dataset)
else:
more_examples = False
break
yield buffer
def load_dataset(tokenizer, filepath, **kwargs):
ds = load_from_disk(filepath)
ds = Dataset.from_generator(ConstantLengthDataset(tokenizer, ds, **kwargs))
return ds