feat: add training data collection for Rosie
Personality Dataset (300+ examples):
- Greetings and farewells
- Emotions and reactions
- Physical interactions (pats, drags, touches)
- Questions and answers
- Help and support
- Jokes and entertainment
- Mood-based responses
- Conversation fillers
- Various user intents

Data Download Script:
- Download Project Gutenberg books (public domain)
- Instructions for OpenWebText (~8B tokens)
- Instructions for The Pile (~300B tokens)
- Automatic dataset combination
- Token counting and statistics
- Download progress bars

Ready to train:
1. Run: python scripts/download_training_data.py --all
2. Download additional datasets as needed
3. Run: python train_rosie.py --data_path data/combined_training.json

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
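For context on the --combine step: create_combined_dataset() in the script below loads data/personality_base.json and reads its top-level 'texts' list, so that is the only structure the personality file needs. A minimal sketch of that layout, using hypothetical placeholder entries rather than the shipped 300+ examples:

import json
import os

# Hypothetical placeholder entries; the script itself only requires a
# top-level "texts" list in this file.
personality = {
    "texts": [
        "User: Hi Rosie!\nRosie: Hey! Good to see you again.",
        "User: *pats Rosie*\nRosie: Hehe, that tickles!",
    ]
}

os.makedirs("data", exist_ok=True)
with open("data/personality_base.json", "w", encoding="utf-8") as f:
    json.dump(personality, f, indent=2)

With that file in place, --combine (or --all) folds those entries into data/combined_training.json alongside the downloaded book paragraphs.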
scripts/download_training_data.py (new file, 251 lines)
@@ -0,0 +1,251 @@
"""
Download Training Data Script

Downloads public domain datasets for training Rosie's base language model.
"""
import argparse
import json
import os
from pathlib import Path

import requests
from tqdm import tqdm


def download_file(url: str, filepath: str, description: str = ""):
    """Download a file with a progress bar"""
    print(f"Downloading {description}...")
    response = requests.get(url, stream=True, timeout=30)
    # Fail fast on HTTP errors so an error page never gets saved as training data
    response.raise_for_status()
    total_size = int(response.headers.get('content-length', 0))

    with open(filepath, 'wb') as f, tqdm(
        desc=description,
        total=total_size,
        unit='iB',
        unit_scale=True,
        unit_divisor=1024,
    ) as pbar:
        for chunk in response.iter_content(chunk_size=8192):
            size = f.write(chunk)
            pbar.update(size)

    print(f"✓ Downloaded to {filepath}\n")


def download_openwebtext_sample():
    """Print instructions for obtaining OpenWebText and create a placeholder directory"""
    print("=" * 60)
    print("OpenWebText Sample")
    print("=" * 60)
    print("OpenWebText is a large web-scraped dataset (~40GB)")
    print("It is not fetched automatically; this step only prepares the directory\n")

    # Note: the full dataset must be downloaded manually from:
    # https://skylion007.github.io/OpenWebTextCorpus/
    print("To get the full OpenWebText dataset:")
    print("1. Visit: https://skylion007.github.io/OpenWebTextCorpus/")
    print("2. Download the .xz files")
    print("3. Extract to data/openwebtext/\n")

    # Create a placeholder directory for the manual download
    os.makedirs('data/openwebtext', exist_ok=True)
    print("✓ Created data/openwebtext/ directory")
    print("  Please download OpenWebText files here\n")


def download_gutenberg_books():
    """Download sample books from Project Gutenberg"""
    print("=" * 60)
    print("Project Gutenberg Books")
    print("=" * 60)
    print("Downloading public domain books for language training\n")

    os.makedirs('data/books', exist_ok=True)

    # Sample books (all public domain)
    books = [
        {
            'url': 'https://www.gutenberg.org/files/1342/1342-0.txt',
            'name': 'Pride and Prejudice',
            'file': 'pride_and_prejudice.txt'
        },
        {
            'url': 'https://www.gutenberg.org/files/11/11-0.txt',
            'name': 'Alice in Wonderland',
            'file': 'alice_in_wonderland.txt'
        },
        {
            'url': 'https://www.gutenberg.org/files/84/84-0.txt',
            'name': 'Frankenstein',
            'file': 'frankenstein.txt'
        },
        {
            'url': 'https://www.gutenberg.org/files/1661/1661-0.txt',
            'name': 'Sherlock Holmes',
            'file': 'sherlock_holmes.txt'
        },
        {
            'url': 'https://www.gutenberg.org/files/2701/2701-0.txt',
            'name': 'Moby Dick',
            'file': 'moby_dick.txt'
        },
    ]

    for book in books:
        filepath = f"data/books/{book['file']}"
        if os.path.exists(filepath):
            print(f"✓ {book['name']} already downloaded")
            continue

        try:
            download_file(book['url'], filepath, book['name'])
        except Exception as e:
            print(f"✗ Failed to download {book['name']}: {e}\n")

    print("✓ Books downloaded\n")


def create_combined_dataset():
    """Combine all downloaded data into training format"""
    print("=" * 60)
    print("Creating Combined Dataset")
    print("=" * 60)

    texts = []

    # Load books
    books_dir = Path('data/books')
    if books_dir.exists():
        print("Processing books...")
        for book_file in books_dir.glob('*.txt'):
            try:
                with open(book_file, 'r', encoding='utf-8') as f:
                    content = f.read()

                # Split into paragraphs
                paragraphs = [p.strip() for p in content.split('\n\n') if len(p.strip()) > 100]
                texts.extend(paragraphs)
                print(f"  ✓ {book_file.name}: {len(paragraphs)} paragraphs")

            except Exception as e:
                print(f"  ✗ Error reading {book_file.name}: {e}")

    # Load personality data
    personality_files = ['data/personality_base.json']
    for pfile in personality_files:
        if os.path.exists(pfile):
            print(f"Loading {pfile}...")
            with open(pfile, 'r', encoding='utf-8') as f:
                data = json.load(f)
            texts.extend(data['texts'])
            print(f"  ✓ {len(data['texts'])} personality examples")

    print(f"\nTotal texts collected: {len(texts)}")

    # Save combined dataset
    output_file = 'data/combined_training.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump({'texts': texts}, f, indent=2)

    print(f"✓ Saved to {output_file}\n")

    # Calculate approximate token count (rough estimate: 1 token ≈ 4 characters)
    total_chars = sum(len(text) for text in texts)
    approx_tokens = total_chars // 4
    print(f"Approximate tokens: {approx_tokens:,} ({approx_tokens/1e6:.1f}M)")
    print("This is a SMALL dataset. For full training, you'll need 10-50B tokens.")
    print("Consider downloading OpenWebText or The Pile for complete training.\n")


def show_dataset_info():
    """Show information about available datasets"""
    print("\n" + "=" * 60)
    print("Available Public Datasets for Training")
    print("=" * 60)
    print()

    datasets = [
        {
            'name': 'OpenWebText',
            'size': '~40GB (38GB compressed)',
            'tokens': '~8B tokens',
            'url': 'https://skylion007.github.io/OpenWebTextCorpus/',
            'description': 'Web-scraped text from Reddit links'
        },
        {
            'name': 'The Pile',
            'size': '~800GB',
            'tokens': '~300B tokens',
            'url': 'https://pile.eleuther.ai/',
            'description': 'Massive diverse text dataset'
        },
        {
            'name': 'BookCorpus',
            'size': '~5GB',
            'tokens': '~1B tokens',
            'url': 'HuggingFace: bookcorpus',
            'description': 'Books corpus (11K books)'
        },
        {
            'name': 'Wikipedia',
            'size': '~20GB',
            'tokens': '~3B tokens',
            'url': 'https://dumps.wikimedia.org/',
            'description': 'Wikipedia dumps (all languages)'
        },
        {
            'name': 'Project Gutenberg',
            'size': '~10GB',
            'tokens': '~2B tokens',
            'url': 'https://www.gutenberg.org/',
            'description': 'Public domain books (60K+ books)'
        },
    ]

    for dataset in datasets:
        print(f"[*] {dataset['name']}")
        print(f"    Size: {dataset['size']}")
        print(f"    Tokens: {dataset['tokens']}")
        print(f"    URL: {dataset['url']}")
        print(f"    Description: {dataset['description']}")
        print()

    print("Recommendation for Rosie training:")
    print("  - Start: Books + Personality data (~500M tokens)")
    print("  - Better: + OpenWebText (~8B tokens)")
    print("  - Best: + The Pile subset (~50B tokens)")
    print()


def main():
    parser = argparse.ArgumentParser(description="Download training data for Rosie")
    parser.add_argument('--books', action='store_true', help='Download sample books')
    parser.add_argument('--info', action='store_true', help='Show dataset information')
    parser.add_argument('--combine', action='store_true', help='Combine downloaded data')
    parser.add_argument('--all', action='store_true', help='Download all available samples')

    args = parser.parse_args()

    # Create data directory
    os.makedirs('data', exist_ok=True)

    if args.info or not any([args.books, args.combine, args.all]):
        show_dataset_info()

    if args.books or args.all:
        download_gutenberg_books()
        download_openwebtext_sample()

    if args.combine or args.all:
        create_combined_dataset()

    print("=" * 60)
    print("Next Steps:")
    print("=" * 60)
    print("1. Download more data (see --info for sources)")
    print("2. Run: python train_rosie.py --data_path data/combined_training.json")
    print("3. Monitor training progress")
    print("4. Test the model with test_rosie.py")
    print()


if __name__ == "__main__":
    main()