Draft
Show file tree
Hide file tree
Changes from all commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Failed to load files.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
import time
from functools import partial

from torch.utils.data import DataLoader
from torcheval.metrics.functional import word_error_rate
from torchtext.datasets import Multi30k
from torchtext.models import T5_BASE_GENERATION, T5_3B_GENERATION
from torchtext..generate import GenerationUtils

multi_batch_size = 16
language_pair = ("en", "de")
multi_datapipe = Multi30k(split="test", language_pair=language_pair)
task = "translate English to German"


def apply_prefix(task, x):
return f"{task}: " + x[0], x[1]


multi_datapipe = multi_datapipe.map(partial(apply_prefix, task))
multi_datapipe = multi_datapipe.batch(multi_batch_size)
multi_datapipe = multi_datapipe.rows2columnar(["english", "german"])
multi_dataloader = DataLoader(multi_datapipe, batch_size=None)


def benchmark_beam_search_wer():
model = T5_BASE_GENERATION.get_model()
transform = T5_BASE_GENERATION.transform()

seq_generator = GenerationUtils(model)

batch = next(iter(multi_dataloader))
input_text = batch["english"]
target = batch["german"]
beam_size = 8

model_input = transform(input_text)
model_output = seq_generator.generate(
model_input,
num_beams=beam_size,
beam_threshold=1000,
vocab_size=model.config.vocab_size,
eos_score=-1.0,
eos_idx=1,
pad_idx=0,
)
output_text = transform.decode(model_output.tolist())

print(word_error_rate(output_text, target))


if __name__ == "__main__":
benchmark_beam_search_wer()
Original file line numberDiff line numberDiff line change
Expand Up@@ -16,7 +16,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
"/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/tqdm-4.64.1-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
Expand All@@ -39,14 +39,14 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
"/data/home/jrcummings/miniconda/envs/torchtext/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:163: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
"For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
"- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
"- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
Expand DownExpand Up@@ -74,7 +74,55 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.']\n"
]
}
],
"source": [
"# Testing HuggingFace's T5 w/ Beam Search\n",
"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n",
"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['a dog is good for you. studies have shown that dog ownership is good for your overall health and well-being.'] 9.786320924758911\n",
"['studies have shown that owning a dog is good for you. studies have shown that owning a dog is good for you.'] 1.3000121116638184\n"
]
}
],
"source": [
"# Testing Decoding Speed HuggingFace's T5 w/ TorchText Beam Search vs. HuggingFace Beam Search\n",
"import time\n",
"\n",
"start = time.time()\n",
"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, beam_size_token=t5.config.vocab_size)\n",
"end = time.time()\n",
"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n",
"\n",
"start = time.time()\n",
"tokens = t5.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n",
"end = time.time()\n",
"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand All@@ -99,7 +147,54 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Nearly. PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions.']\n"
]
}
],
"source": [
"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id, num_beams=5, beam_size_token=bart.config.vocab_size)\n",
"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts are expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the shutoffs which were expected to last through at least midday tomorrow. to be affected by the'] 58.09997892379761\n",
"['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The blackouts were expected to last through at least midday tomorrow.'] 2.456479787826538\n"
]
}
],
"source": [
"# Testing Decoding Speed HuggingFace's BART w/ TorchText Beam Search vs. HuggingFace Beam Search\n",
"import time\n",
"\n",
"start = time.time()\n",
"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=100, pad_idx=t5.config.pad_token_id, num_beams=5, eos_score=1.0, beam_size_token=t5.config.vocab_size)\n",
"end = time.time()\n",
"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)\n",
"\n",
"start = time.time()\n",
"tokens = bart.generate(test_sequence_tk, max_length=100, num_beams=5, do_sample=False)\n",
"end = time.time()\n",
"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True), end - start)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All@@ -119,11 +214,29 @@
"tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n",
"print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['I enjoy walking with my cute dog,\" says Kelli Williams-Petersen. The dog loves it so much, that when she']\n"
]
}
],
"source": [
"tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id, num_beams=5, beam_size_token=gpt2.config.vocab_size)\n",
"print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.13 ('torchtext39')",
"display_name": "torchtext",
"language": "python",
"name": "python3"
},
Expand All@@ -137,12 +250,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
"version": "3.9.15"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
"hash": "1851d106532ddfc6fbd983b9ae95397243fcc3930d811046c990ea169e960650"
}
}
},
Expand Down
Loading