| from datasets import load_dataset |
| from tokenizers import Tokenizer |
| from torch.utils.data import DataLoader, Dataset |
| import torch |
| from transformer_chat import TransformerChatbot |
| import pandas as pd |
| import random |
|
|
| |
| raw_dataset = load_dataset("tuetschek/atis", split="train") |
|
|
| |
| tokenizer = Tokenizer.from_file('tokenizer.json') |
|
|
| |
| def create_response_for_intent(intent, text): |
| """Create synthetic responses for ATIS intents""" |
| responses = { |
| 'atis_flight': [ |
| "I can help you with flight information. What specific details do you need?", |
| "I'll search for flights matching your criteria. Please provide departure and arrival cities.", |
| "Let me find available flights for you. When would you like to travel?" |
| ], |
| 'atis_flight_no': [ |
| "I can help you with flight number information. Please provide the flight number.", |
| "Let me search for details about that flight number.", |
| "I'll look up information for that specific flight." |
| ], |
| 'atis_airfare': [ |
| "I can help you find airfare information. What's your travel route?", |
| "Let me search for the best airfare options for your trip.", |
| "I'll check current airfare prices for your destination." |
| ], |
| 'atis_airline': [ |
| "I can help you with airline information. Which airline are you looking for?", |
| "Let me provide information about that airline.", |
| "I'll search for details about the airline you mentioned." |
| ], |
| 'atis_abbreviation': [ |
| "I can help you with airport abbreviations. Which abbreviation do you need?", |
| "Let me explain that airport abbreviation for you.", |
| "I'll provide the full name for that airport code." |
| ], |
| 'atis_airport': [ |
| "I can help you with airport information. Which airport are you looking for?", |
| "Let me provide details about that airport.", |
| "I'll search for information about the airport you mentioned." |
| ], |
| 'atis_distance': [ |
| "I can help you calculate distances between airports. Which airports are you interested in?", |
| "Let me calculate the distance for you.", |
| "I'll provide distance information between those locations." |
| ], |
| 'atis_ground_service': [ |
| "I can help you with ground transportation services. What type of service do you need?", |
| "Let me find ground transportation options for you.", |
| "I'll search for available ground services at your destination." |
| ], |
| 'atis_aircraft': [ |
| "I can help you with aircraft information. What type of aircraft are you looking for?", |
| "Let me provide details about that aircraft type.", |
| "I'll search for information about the aircraft you mentioned." |
| ], |
| 'atis_capacity': [ |
| "I can help you with capacity information. What specific capacity details do you need?", |
| "Let me check the capacity for that flight or aircraft.", |
| "I'll provide capacity information for your query." |
| ], |
| 'atis_quantity': [ |
| "I can help you with quantity information. What specific quantity are you looking for?", |
| "Let me check the quantity for that item or service.", |
| "I'll provide quantity information for your request." |
| ], |
| 'atis_meal': [ |
| "I can help you with meal information. What type of meal service are you looking for?", |
| "Let me check meal options for your flight.", |
| "I'll provide information about meal services available." |
| ], |
| 'atis_cheapest': [ |
| "I can help you find the cheapest options. What's your travel route?", |
| "Let me search for the most affordable options for your trip.", |
| "I'll find the cheapest flights or services for you." |
| ], |
| 'atis_restriction': [ |
| "I can help you with travel restrictions. What type of restrictions are you asking about?", |
| "Let me check the restrictions for your travel plans.", |
| "I'll provide information about travel restrictions." |
| ], |
| 'atis_day_name': [ |
| "I can help you with day information. What specific day are you looking for?", |
| "Let me check the schedule for that day.", |
| "I'll provide information about flights or services on that day." |
| ] |
| } |
| |
| |
| base_responses = responses.get(intent, [ |
| "I can help you with that. Please provide more details.", |
| "Let me assist you with your request.", |
| "I'll help you find the information you need." |
| ]) |
| |
| |
| if "flight" in text.lower(): |
| base_responses.extend([ |
| "I can help you book a flight. What are your travel dates?", |
| "Let me search for available flights for you.", |
| "I'll help you find the best flight options." |
| ]) |
| |
| return random.choice(base_responses) |
|
|
| |
| def create_training_pairs(): |
| training_data = [] |
| |
| for item in raw_dataset: |
| question = item['text'] |
| intent = item['intent'] |
| response = create_response_for_intent(intent, question) |
| |
| |
| question_encoding = tokenizer.encode(question) |
| response_encoding = tokenizer.encode(response) |
| |
| |
| question_ids = [tokenizer.token_to_id("[CLS]")] + question_encoding.ids + [tokenizer.token_to_id("[SEP]")] |
| response_ids = [tokenizer.token_to_id("[CLS]")] + response_encoding.ids + [tokenizer.token_to_id("[SEP]")] |
| |
| training_data.append({ |
| 'question_ids': question_ids, |
| 'response_ids': response_ids, |
| 'question_len': len(question_ids), |
| 'response_len': len(response_ids) |
| }) |
| |
| return training_data |
|
|
| |
| class AtisGenerationDataset(Dataset): |
| def __init__(self, training_data, tokenizer, max_length=128): |
| self.training_data = training_data |
| self.tokenizer = tokenizer |
| self.max_length = max_length |
|
|
| def __len__(self): |
| return len(self.training_data) |
|
|
| def __getitem__(self, idx): |
| item = self.training_data[idx] |
| |
| |
| question_ids = item['question_ids'][:self.max_length//2] |
| response_ids = item['response_ids'][:self.max_length//2] |
| |
| |
| question_ids += [tokenizer.token_to_id("[PAD]")] * (self.max_length//2 - len(question_ids)) |
| response_ids += [tokenizer.token_to_id("[PAD]")] * (self.max_length//2 - len(response_ids)) |
| |
| return ( |
| torch.tensor(question_ids), |
| torch.tensor(response_ids), |
| torch.tensor(item['question_len']), |
| torch.tensor(item['response_len']) |
| ) |
|
|
| |
| print("Creating training data...") |
| training_data = create_training_pairs() |
| print(f"Created {len(training_data)} training pairs") |
|
|
| |
| atis_dataset = AtisGenerationDataset(training_data, tokenizer) |
| dataloader = DataLoader(atis_dataset, batch_size=16, shuffle=True) |
|
|
| |
| vocab_size = tokenizer.get_vocab_size() |
| model = TransformerChatbot( |
| vocab_size=vocab_size, |
| d_model=512, |
| num_heads=8, |
| d_ff=2048, |
| num_encoder_layers=6, |
| num_decoder_layers=6, |
| num_roles=2, |
| max_turns=16, |
| num_slots=len(set(item['intent'] for item in raw_dataset)), |
| dropout=0.1 |
| ) |
|
|
| |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| model.to(device) |
|
|
| |
| optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) |
| loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("[PAD]")) |
|
|
| print("Starting training...") |
| for epoch in range(10): |
| model.train() |
| total_loss = 0 |
| for batch_idx, (question_ids, response_ids, question_lens, response_lens) in enumerate(dataloader): |
| question_ids = question_ids.to(device) |
| response_ids = response_ids.to(device) |
| |
| batch_size, seq_len = question_ids.shape |
| |
| |
| roles = torch.zeros_like(question_ids) |
| turns = torch.zeros_like(question_ids) |
| |
| gen_logits, slot_logits = model( |
| question_ids, response_ids, |
| roles, roles, |
| turns, turns |
| ) |
| |
| |
| target_ids = response_ids[:, 1:] |
| gen_logits = gen_logits[:, :-1, :] |
| |
| gen_logits_flat = gen_logits.reshape(-1, vocab_size) |
| target_ids_flat = target_ids.reshape(-1) |
| loss = loss_fn(gen_logits_flat, target_ids_flat) |
| optimizer.zero_grad() |
| loss.backward() |
| optimizer.step() |
| total_loss += loss.item() |
| if batch_idx % 100 == 0: |
| print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}") |
| |
| avg_loss = total_loss / len(dataloader) |
| print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}") |
|
|
| |
| print("Saving model...") |
| torch.save(model.state_dict(), 'atis_transformer.pt') |
| print("Training completed!") |