Advanced Reward Functions

This guide covers advanced patterns and techniques for creating sophisticated reward functions.

Overview

Advanced reward functions go beyond simple accuracy checks to provide nuanced evaluation that considers multiple factors, context, and domain-specific requirements.

Multi-Metric Evaluation

Combine multiple evaluation criteria:

from reward_kit import reward_function
from reward_kit.rewards import accuracy
import numpy as np

@reward_function
def multi_metric_reward(response: str, expected_response: str, **kwargs) -> float:
    """
    Advanced reward function combining accuracy, length, and format compliance.
    """
    # Base accuracy score
    acc_score = accuracy(response, expected_response)

    # Length appropriateness (prefer responses between 50-200 chars)
    len_score = length_appropriateness(response, min_len=50, max_len=200)

    # Format compliance (if response should follow a pattern)
    format_score = check_format_compliance(response)

    # Weighted combination
    weights = [0.6, 0.2, 0.2]  # accuracy, length, format
    scores = [acc_score, len_score, format_score]

    return np.average(scores, weights=weights)

def length_appropriateness(response: str, min_len: int, max_len: int) -> float:
    """Helper function to score length appropriateness."""
    length = len(response)
    if min_len <= length <= max_len:
        return 1.0
    elif length < min_len:
        return max(0.0, length / min_len)
    else:
        return max(0.0, 1.0 - (length - max_len) / max_len)

def check_format_compliance(response: str) -> float:
    """Helper function to check format compliance."""
    # Example: Check if response follows expected structure
    if response.startswith("Answer:") and response.endswith("."):
        return 1.0
    return 0.5
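
Because the helpers are plain Python, they can be sanity-checked in isolation before the decorated function is wired into an evaluation run; the expected values below follow directly from the formulas above:

# Quick checks of the standalone helpers (no reward_kit dependency needed)
print(length_appropriateness("x" * 100, min_len=50, max_len=200))  # 1.0 (inside the 50-200 range)
print(length_appropriateness("x" * 25, min_len=50, max_len=200))   # 0.5 (half of min_len)
print(length_appropriateness("x" * 300, min_len=50, max_len=200))  # 0.5 (50% over max_len)
print(check_format_compliance("Answer: 42."))                      # 1.0 (matches the expected pattern)
print(check_format_compliance("42"))                               # 0.5 (partial credit)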

Context-Aware Evaluation

Consider context when evaluating responses:

@reward_function
def context_aware_reward(response: str, expected_response: str, context: dict | None = None, **kwargs) -> float:
    """
    Reward function that considers context information.
    """
    base_score = accuracy(response, expected_response)

    if context:
        # Adjust score based on difficulty
        difficulty = context.get('difficulty', 'medium')
        if difficulty == 'hard' and base_score > 0.8:
            base_score *= 1.2  # Bonus for strong answers to hard questions
        elif difficulty == 'easy' and base_score < 0.5:
            base_score *= 0.8  # Extra penalty for weak answers to easy questions

        # Consider response time if available
        response_time = context.get('response_time_seconds', 0)
        if response_time > 0:
            # Small bonus (up to 0.1) for fast responses
            time_bonus = max(0, (10 - response_time) / 100)
            base_score += time_bonus

    return min(base_score, 1.0)
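
The context keys read above ('difficulty', 'response_time_seconds') are an assumption about what your evaluation harness passes in, not a fixed reward-kit schema; a hypothetical payload might look like this:

# Hypothetical context payload; the keys mirror the lookups above.
example_context = {
    "difficulty": "hard",          # one of 'easy', 'medium', 'hard'
    "response_time_seconds": 3.5,  # wall-clock latency of the model's answer
}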

Domain-Specific Evaluation

Create reward functions tailored to specific domains:

@reward_function
def code_quality_reward(response: str, expected_response: str, **kwargs) -> float:
    """
    Evaluates code responses considering multiple quality factors.
    """
    import ast

    score = 0.0

    # Check if code is syntactically valid
    try:
        ast.parse(response)
        score += 0.3  # Syntax correctness
    except SyntaxError:
        return 0.0  # Invalid syntax gets zero score

    # Check for best practices
    if "def " in response:  # Function definition
        score += 0.2

    if "# " in response or '"""' in response:  # Comments/docstrings
        score += 0.1

    # Check for specific patterns
    if "import " in response and "from " in response:
        score += 0.1  # Good import practices

    # Length consideration (not too short, not too long)
    lines = response.split('\n')
    if 5 <= len(lines) <= 50:
        score += 0.1

    # Functional correctness (if test cases available)
    test_cases = kwargs.get('test_cases', [])
    if test_cases:
        correctness_score = evaluate_code_correctness(response, test_cases)
        score += 0.2 * correctness_score

    return min(score, 1.0)

def evaluate_code_correctness(code: str, test_cases: list) -> float:
    """Helper to evaluate code correctness against test cases."""
    # This would implement actual code execution and testing
    # For safety, this is a placeholder
    return 0.8  # Placeholder score
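
One way to replace the placeholder is to run each test case in a subprocess with a timeout, which limits hangs but is not a full sandbox. The sketch below assumes test cases are dicts of the form {'call': 'add(2, 3)', 'expected': 5}; both that format and the helper name run_test_cases are illustrative, not a reward-kit convention.

import os
import subprocess
import sys
import tempfile

def run_test_cases(code: str, test_cases: list) -> float:
    """Sketch: run candidate code plus one assertion per test case in a
    subprocess with a timeout. Assumes each test case is a dict like
    {"call": "add(2, 3)", "expected": 5} (hypothetical format)."""
    if not test_cases:
        return 0.0
    passed = 0
    for case in test_cases:
        script = code + f"\nassert {case['call']} == {case['expected']!r}\n"
        with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
            f.write(script)
            path = f.name
        try:
            result = subprocess.run(
                [sys.executable, path], capture_output=True, timeout=5
            )
            if result.returncode == 0:
                passed += 1
        except subprocess.TimeoutExpired:
            pass  # Treat hangs as failed test cases
        finally:
            os.unlink(path)
    return passed / len(test_cases)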

Statistical Evaluation

Use statistical methods for more robust evaluation:

from scipy.stats import pearsonr
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

@reward_function
def statistical_similarity_reward(response: str, expected_response: str, **kwargs) -> float:
    """
    Uses statistical methods to evaluate response similarity.
    """
    # Convert to numerical representations (e.g., using embeddings)
    response_embedding = get_text_embedding(response)
    expected_embedding = get_text_embedding(expected_response)

    # Cosine similarity
    cos_sim = cosine_similarity([response_embedding], [expected_embedding])[0][0]

    # Pearson correlation between the embedding vectors
    if len(response_embedding) == len(expected_embedding):
        corr, _ = pearsonr(response_embedding, expected_embedding)
        corr = 0.0 if np.isnan(corr) else max(0.0, corr)  # Guard against constant vectors
    else:
        corr = 0.0

    # Combine metrics
    final_score = 0.7 * cos_sim + 0.3 * corr

    return max(0.0, min(1.0, final_score))

def get_text_embedding(text: str) -> np.ndarray:
    """Placeholder for text embedding function."""
    # In practice, use a real embedding model
    return np.random.rand(100)  # Placeholder
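
In practice, the placeholder can be swapped for a real embedding model. The sketch below assumes the optional sentence-transformers package is installed (it is not a reward-kit dependency); any model that returns fixed-length vectors works the same way.

from sentence_transformers import SentenceTransformer

# Load the model once at import time so repeated reward calls reuse it.
_embedder = SentenceTransformer("all-MiniLM-L6-v2")

def get_text_embedding(text: str) -> np.ndarray:
    """Encode text into a dense vector with a small sentence-transformer model."""
    return _embedder.encode(text)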

Hierarchical Evaluation

Create reward functions with hierarchical evaluation:

@reward_function
def hierarchical_reward(response: str, expected_response: str, **kwargs) -> float:
    """
    Hierarchical evaluation with multiple levels of assessment.
    """
    # Level 1: Basic format validation
    if not basic_format_check(response):
        return 0.0

    # Level 2: Content relevance
    relevance_score = content_relevance(response, expected_response)
    if relevance_score < 0.3:
        return relevance_score * 0.5  # Heavily discount low-relevance responses

    # Level 3: Detailed accuracy
    accuracy_score = detailed_accuracy(response, expected_response)

    # Level 4: Style and presentation
    style_score = evaluate_style(response)

    # Weighted combination based on hierarchy
    final_score = (
        0.1 * 1.0 +  # Format passed
        0.3 * relevance_score +
        0.5 * accuracy_score +
        0.1 * style_score
    )

    return final_score

def basic_format_check(response: str) -> bool:
    """Basic format validation."""
    return len(response.strip()) > 0 and len(response) < 10000

def content_relevance(response: str, expected: str) -> float:
    """Evaluate content relevance."""
    # Simple lexical-overlap proxy for semantic similarity
    common_words = set(response.lower().split()) & set(expected.lower().split())
    return len(common_words) / max(len(set(expected.lower().split())), 1)

def detailed_accuracy(response: str, expected: str) -> float:
    """Detailed accuracy evaluation."""
    return accuracy(response, expected)

def evaluate_style(response: str) -> float:
    """Evaluate writing style and presentation."""
    if not response:
        return 0.0
    score = 0.0
    if response[0].isupper():  # Starts with a capital letter
        score += 0.3
    if response.endswith('.'):  # Ends with a period
        score += 0.3
    if 10 <= len(response.split()) <= 100:  # Appropriate length
        score += 0.4
    return score
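
The gating helpers are plain Python and can be exercised directly; a few illustrative calls (values follow from the definitions above):

print(basic_format_check(""))                                   # False (level 1 fails, reward is 0.0)
print(basic_format_check("Paris is the capital of France."))    # True
print(content_relevance("Paris is the capital",
                        "The capital of France is Paris"))      # ~0.67 (4 of 6 expected words overlap)
print(evaluate_style("Paris is the capital of France."))        # 0.6 (too short for the length bonus)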

Best Practices for Advanced Rewards

  1. Modularity: Break complex evaluation into smaller, testable functions
  2. Robustness: Handle edge cases and invalid inputs gracefully
  3. Transparency: Make evaluation criteria clear and interpretable
  4. Validation: Test reward functions on diverse examples (see the sketch after this list)
  5. Performance: Consider computational efficiency for large-scale evaluation
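
A minimal sketch of the validation point, written as pytest-style tests against the helper functions defined earlier in this guide (the module layout is hypothetical; adjust the imports to wherever your helpers live):

# test_reward_helpers.py (hypothetical test module, run with pytest)
# Assumes the helpers are importable, e.g.:
# from my_rewards import length_appropriateness, evaluate_style

def test_length_appropriateness_stays_in_unit_interval():
    for n in (0, 10, 50, 125, 200, 500, 5000):
        score = length_appropriateness("x" * n, min_len=50, max_len=200)
        assert 0.0 <= score <= 1.0

def test_length_appropriateness_prefers_in_range_responses():
    in_range = length_appropriateness("x" * 100, min_len=50, max_len=200)
    too_short = length_appropriateness("x" * 10, min_len=50, max_len=200)
    assert in_range == 1.0
    assert too_short < in_range

def test_evaluate_style_handles_empty_input():
    # Edge case: an empty response should score 0.0 rather than raise.
    assert evaluate_style("") == 0.0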

Next Steps