TRL - Transformer Reinforcement Learning

Overview

TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the 🤗 Transformers ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups.

Highlights

Efficient and scalable:

  • Leverages 🤗 Accelerate to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
  • Full integration with PEFT enables training on large models with modest hardware via quantization and LoRA/QLoRA (see the LoRA sketch below).
  • Integrates Unsloth for accelerating training using optimized kernels.

Command Line Interface (CLI): A simple interface lets you fine-tune and interact with models without needing to write code.

Trainers: Various fine-tuning methods are easily accessible via trainers like SFTTrainer, DPOTrainer, RewardTrainer, ORPOTrainer, and more.

AutoModels: Use pre-defined model classes like AutoModelForCausalLMWithValueHead to simplify reinforcement learning (RL) with LLMs.
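
To illustrate the PEFT integration, here is a minimal sketch of parameter-efficient SFT with LoRA, assuming the Qwen/Qwen2.5-0.5B model and trl-lib/Capybara dataset used elsewhere in this README and that the peft package is installed; the LoRA hyperparameters are illustrative, not recommended defaults:

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# Illustrative LoRA settings; tune rank, alpha, and dropout for your model and budget.
peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")

training_args = SFTConfig(output_dir="Qwen2.5-0.5B-SFT-LoRA")
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,  # train LoRA adapters instead of the full model weights
)
trainer.train()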

Installation

Python Package

Install the library using pip:

pip install trl

From source

If you want to use the latest features before an official release, you can install TRL from source:

pip install git+https://github.com/huggingface/trl.git

Repository

If you want to use the examples, you can clone the repository with the following command:

git clone https://github.com/huggingface/trl.git

Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-Tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI:

SFT:

trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara \
    --output_dir Qwen2.5-0.5B-SFT

DPO:

trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO

Chat:

trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct

Read more about the CLI in the relevant documentation section or use --help for more details.

How to use

For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
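
Because each trainer wraps the 🤗 Transformers Trainer, its config class (SFTConfig, DPOConfig, and so on) extends TrainingArguments, so the usual training options can be passed alongside TRL-specific ones. A small sketch with commonly used fields (the values here are illustrative, not tuned recommendations):

from trl import SFTConfig

# SFTConfig extends transformers.TrainingArguments, so the usual knobs apply.
training_args = SFTConfig(
    output_dir="Qwen2.5-0.5B-SFT",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    bf16=True,          # mixed precision, if your hardware supports it
    logging_steps=10,
)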

SFTTrainer

Here is a basic example of how to use the SFTTrainer:

from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
    args=training_args,
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()
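
Since SFTTrainer builds on the 🤗 Transformers Trainer, the usual follow-up steps are available once training finishes; a brief sketch (pushing to the Hub is optional and requires being logged in):

trainer.save_model(training_args.output_dir)  # write the final model to output_dir
trainer.push_to_hub()                         # optionally upload it to the Hugging Face Hub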

RewardTrainer

Here is a basic example of how to use the RewardTrainer:

from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
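
RewardTrainer (and DPOTrainer below) expect preference pairs. As a sketch, a record in an implicit-preference dataset such as trl-lib/ultrafeedback_binarized follows the conventional "chosen"/"rejected" layout; the contents below are made up for illustration:

example = {
    "chosen": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ],
    "rejected": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "I believe it is Lyon."},
    ],
}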

RLOOTrainer

RLOOTrainer implements a REINFORCE-style optimization for RLHF that is more performant and memory-efficient than PPO. Here is a basic example of how to use the RLOOTrainer:

from trl import RLOOConfig, RLOOTrainer, apply_chat_template
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

dataset = load_dataset("trl-lib/ultrafeedback-prompt")
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
dataset = dataset.map(lambda x: tokenizer(x["prompt"]), remove_columns="prompt")

training_args = RLOOConfig(output_dir="Qwen2.5-0.5B-RL")
trainer = RLOOTrainer(
    config=training_args,
    processing_class=tokenizer,
    policy=policy,
    ref_policy=ref_policy,
    reward_model=reward_model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)
trainer.train()
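
The "leave-one-out" in RLOO refers to its baseline: for each prompt, k completions are sampled, and each completion's reward is compared against the mean reward of the other k - 1 completions. A conceptual sketch of that advantage computation (not TRL's internal implementation):

import torch

def rloo_advantages(rewards: torch.Tensor) -> torch.Tensor:
    """rewards holds the scalar reward of each of k sampled completions for one prompt."""
    k = rewards.numel()
    # The baseline for sample i is the mean reward of the other k - 1 samples.
    baseline = (rewards.sum() - rewards) / (k - 1)
    return rewards - baseline

# Example: four sampled completions for the same prompt.
print(rloo_advantages(torch.tensor([1.0, 0.5, 0.0, 2.0])))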

DPOTrainer

DPOTrainer implements the popular Direct Preference Optimization (DPO) algorithm that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the DPOTrainer:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()
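
For intuition, DPO's per-pair loss compares the policy's and a frozen reference model's log-probabilities of the chosen and rejected responses. A conceptual sketch of the objective (beta is the usual temperature hyperparameter; this is not TRL's implementation):

import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Each argument is a tensor of summed response log-probabilities, shape (batch,)."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    # Push the implicit reward of the chosen response above that of the rejected one.
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()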

Development

If you want to contribute to trl or customize it to your needs, make sure to read the contribution guide and to do a dev install:

git clone https://github.com/huggingface/trl.git
cd trl/
make dev

Citation

@misc{vonwerra2022trl,
  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
  title = {TRL: Transformer Reinforcement Learning},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/trl}}
}

License

This repository's source code is available under the Apache-2.0 License.