Advanced Reward Functions
This guide covers advanced patterns and techniques for creating sophisticated reward functions.
Overview
Advanced reward functions go beyond simple accuracy checks to provide nuanced evaluation that considers multiple factors, context, and domain-specific requirements.
Multi-Metric Evaluation
Combine multiple evaluation criteria:
from reward_kit import reward_function
from reward_kit.rewards import accuracy
import numpy as np

@reward_function
def multi_metric_reward(response: str, expected_response: str, **kwargs) -> float:
    """
    Advanced reward function combining accuracy, length, and format compliance.
    """
    # Base accuracy score
    acc_score = accuracy(response, expected_response)

    # Length appropriateness (prefer responses between 50 and 200 characters)
    len_score = length_appropriateness(response, min_len=50, max_len=200)

    # Format compliance (if the response should follow a pattern)
    format_score = check_format_compliance(response)

    # Weighted combination: accuracy, length, format
    weights = [0.6, 0.2, 0.2]
    scores = [acc_score, len_score, format_score]
    return float(np.average(scores, weights=weights))


def length_appropriateness(response: str, min_len: int, max_len: int) -> float:
    """Helper function to score length appropriateness."""
    length = len(response)
    if min_len <= length <= max_len:
        return 1.0
    elif length < min_len:
        return max(0.0, length / min_len)
    else:
        return max(0.0, 1.0 - (length - max_len) / max_len)


def check_format_compliance(response: str) -> float:
    """Helper function to check format compliance."""
    # Example: full credit if the response follows the expected structure
    if response.startswith("Answer:") and response.endswith("."):
        return 1.0
    return 0.5
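For intuition, the length helper above degrades linearly outside the target window, so a response at half the minimum length (or 50% past the maximum) scores 0.5. A few quick checks, using only the helper defined above:

# Inside the 50-200 character window: full credit
assert length_appropriateness("x" * 100, min_len=50, max_len=200) == 1.0

# Half the minimum length: 25 / 50 = 0.5
assert length_appropriateness("x" * 25, min_len=50, max_len=200) == 0.5

# 100 characters past the maximum: 1.0 - 100 / 200 = 0.5
assert length_appropriateness("x" * 300, min_len=50, max_len=200) == 0.5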
Context-Aware Evaluation
Consider context when evaluating responses:
from typing import Optional

@reward_function
def context_aware_reward(response: str, expected_response: str, context: Optional[dict] = None) -> float:
    """
    Reward function that considers context information.
    """
    base_score = accuracy(response, expected_response)

    if context:
        # Adjust score based on difficulty
        difficulty = context.get('difficulty', 'medium')
        if difficulty == 'hard' and base_score > 0.8:
            base_score *= 1.2  # Bonus for hard questions
        elif difficulty == 'easy' and base_score < 0.5:
            base_score *= 0.8  # Penalty for easy questions

        # Consider response time if available
        response_time = context.get('response_time_seconds', 0)
        if response_time > 0:
            # Slight bonus for quick, accurate responses
            time_bonus = max(0, (10 - response_time) / 100)
            base_score += time_bonus

    return min(base_score, 1.0)
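To illustrate the arithmetic (with the base score hardcoded instead of coming from accuracy): a 0.9 base score on a 'hard' item answered in 3 seconds picks up both adjustments and is then capped at 1.0.

# Illustration only: hardcoded base score of 0.9
base_score = 0.9
base_score *= 1.2                      # 'hard' bonus -> 1.08
base_score += max(0, (10 - 3) / 100)   # 3-second response -> +0.07
print(min(base_score, 1.0))            # capped at 1.0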
Domain-Specific Evaluation
Create reward functions tailored to specific domains:
import ast

@reward_function
def code_quality_reward(response: str, expected_response: str, **kwargs) -> float:
    """
    Evaluates code responses considering multiple quality factors.
    """
    score = 0.0

    # Check if the code is syntactically valid
    try:
        ast.parse(response)
        score += 0.3  # Syntax correctness
    except SyntaxError:
        return 0.0  # Invalid syntax gets a zero score

    # Check for best practices
    if "def " in response:  # Defines a function
        score += 0.2
    if "# " in response or '"""' in response:  # Comments or docstrings
        score += 0.1
    if "import " in response:  # Uses imports
        score += 0.1

    # Length consideration (not too short, not too long)
    lines = response.split('\n')
    if 5 <= len(lines) <= 50:
        score += 0.1

    # Functional correctness (if test cases are available)
    test_cases = kwargs.get('test_cases', [])
    if test_cases:
        correctness_score = evaluate_code_correctness(response, test_cases)
        score += 0.2 * correctness_score

    return min(score, 1.0)


def evaluate_code_correctness(code: str, test_cases: list) -> float:
    """Helper to evaluate code correctness against test cases."""
    # This would execute the code against the test cases in a sandbox.
    # For safety, this is a placeholder.
    return 0.8  # Placeholder score
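A minimal sketch of what evaluate_code_correctness could look like. It assumes each test case is a dict with "input" and "expected_output" strings, and that running untrusted code in a subprocess with a timeout is acceptable isolation for your environment; a proper sandbox is safer for production use.

import subprocess
import sys

def evaluate_code_correctness(code: str, test_cases: list) -> float:
    """Run the candidate code in a subprocess and compare stdout to the expected output.

    Assumes each test case looks like {"input": "...", "expected_output": "..."}.
    """
    if not test_cases:
        return 0.0

    passed = 0
    for case in test_cases:
        try:
            result = subprocess.run(
                [sys.executable, "-c", code],
                input=case["input"],
                capture_output=True,
                text=True,
                timeout=5,  # Guard against infinite loops
            )
            if result.returncode == 0 and result.stdout.strip() == case["expected_output"].strip():
                passed += 1
        except subprocess.TimeoutExpired:
            continue  # Count timeouts as failures

    return passed / len(test_cases)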
Statistical Evaluation
Use statistical methods for more robust evaluation:
from scipy.stats import pearsonr
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

@reward_function
def statistical_similarity_reward(response: str, expected_response: str, **kwargs) -> float:
    """
    Uses statistical methods to evaluate response similarity.
    """
    # Convert to numerical representations (e.g., using embeddings)
    response_embedding = get_text_embedding(response)
    expected_embedding = get_text_embedding(expected_response)

    # Cosine similarity
    cos_sim = cosine_similarity([response_embedding], [expected_embedding])[0][0]

    # Pearson correlation (if applicable)
    if len(response_embedding) == len(expected_embedding):
        corr, _ = pearsonr(response_embedding, expected_embedding)
        corr = max(0, corr)  # Only count positive correlations
    else:
        corr = 0

    # Combine metrics
    final_score = 0.7 * cos_sim + 0.3 * corr
    return max(0.0, min(1.0, final_score))


def get_text_embedding(text: str) -> np.ndarray:
    """Placeholder for a text embedding function."""
    # In practice, use a real embedding model
    return np.random.rand(100)  # Placeholder
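One simple, dependency-light stand-in for get_text_embedding is a hashed bag-of-words vector from scikit-learn (already imported above for cosine similarity). Treat this as an illustrative assumption rather than a recommendation: a real sentence-embedding model captures semantics far better than token hashing.

from sklearn.feature_extraction.text import HashingVectorizer
import numpy as np

# Deterministic, dependency-light embedding: hashed token counts.
_vectorizer = HashingVectorizer(n_features=256, norm=None, alternate_sign=False)

def get_text_embedding(text: str) -> np.ndarray:
    """Embed a single text as a fixed-length hashed bag-of-words vector."""
    return _vectorizer.transform([text]).toarray()[0]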
Hierarchical Evaluation
Structure the evaluation as a hierarchy, where cheap early checks gate the more expensive ones:
@reward_function
def hierarchical_reward(response: str, expected_response: str, **kwargs) -> float:
    """
    Hierarchical evaluation with multiple levels of assessment.
    """
    # Level 1: Basic format validation
    if not basic_format_check(response):
        return 0.0

    # Level 2: Content relevance
    relevance_score = content_relevance(response, expected_response)
    if relevance_score < 0.3:
        return relevance_score * 0.5  # Cap scores with low relevance

    # Level 3: Detailed accuracy
    accuracy_score = detailed_accuracy(response, expected_response)

    # Level 4: Style and presentation
    style_score = evaluate_style(response)

    # Weighted combination based on the hierarchy
    final_score = (
        0.1 * 1.0 +  # Format check passed
        0.3 * relevance_score +
        0.5 * accuracy_score +
        0.1 * style_score
    )
    return final_score


def basic_format_check(response: str) -> bool:
    """Basic format validation."""
    return len(response.strip()) > 0 and len(response) < 10000


def content_relevance(response: str, expected: str) -> float:
    """Evaluate content relevance via word overlap (a stand-in for semantic similarity)."""
    common_words = set(response.lower().split()) & set(expected.lower().split())
    return len(common_words) / max(len(set(expected.lower().split())), 1)


def detailed_accuracy(response: str, expected: str) -> float:
    """Detailed accuracy evaluation."""
    return accuracy(response, expected)


def evaluate_style(response: str) -> float:
    """Evaluate writing style and presentation."""
    score = 0.0
    stripped = response.strip()
    if stripped and stripped[0].isupper():  # Starts with a capital letter
        score += 0.3
    if stripped.endswith('.'):  # Ends with a period
        score += 0.3
    if 10 <= len(response.split()) <= 100:  # Appropriate length
        score += 0.4
    return score
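For example, content_relevance scores word overlap against the expected answer's unique vocabulary, so a response sharing three of five expected words scores 0.6:

# Expected vocabulary: {"the", "cat", "sat", "on", "mat"} (5 unique words)
# The response shares 3 of them, so the relevance score is 3 / 5 = 0.6
score = content_relevance("The cat sat", "the cat sat on the mat")
assert abs(score - 0.6) < 1e-9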
Best Practices for Advanced Rewards
- Modularity: Break complex evaluation into smaller, testable functions
- Robustness: Handle edge cases and invalid inputs gracefully
- Transparency: Make evaluation criteria clear and interpretable
- Validation: Test reward functions on diverse examples and edge cases (see the sketch after this list)
- Performance: Consider computational efficiency for large-scale evaluation
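A minimal pytest-style sketch of such validation, exercising the plain helper functions defined earlier in this guide (the import path is hypothetical and depends on where you place the helpers):

# test_rewards.py -- run with pytest
# from my_rewards import length_appropriateness, check_format_compliance, evaluate_style  # hypothetical module

def test_length_appropriateness_handles_edge_cases():
    assert length_appropriateness("", min_len=50, max_len=200) == 0.0
    assert length_appropriateness("x" * 125, min_len=50, max_len=200) == 1.0

def test_format_compliance_rewards_expected_structure():
    assert check_format_compliance("Answer: 42.") == 1.0
    assert check_format_compliance("42") == 0.5

def test_style_score_is_robust_to_empty_input():
    # Should not raise on degenerate inputs
    assert evaluate_style("") == 0.0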