feat: Add Streamlit dashboard with Blueprint compliance (v2.1.0)
Dashboard Features: - 8 navigation sections: Overview, Outcomes, Poor CX, FCR, Churn, Agent, Call Explorer, Export - Beyond Brand Identity styling (colors #6D84E3, Outfit font) - RCA Sankey diagram (Driver → Outcome → Churn Risk flow) - Correlation heatmaps (driver co-occurrence, driver-outcome) - Outcome Deep Dive (root causes, correlation, duration analysis) - Export functionality (Excel, HTML, JSON) Blueprint Compliance: - FCR: 4 categories (Primera Llamada/Rellamada × Sin/Con Riesgo de Fuga) - Churn: Binary view (Sin Riesgo de Fuga / En Riesgo de Fuga) - Agent: Talento Para Replicar / Oportunidades de Mejora - Fixed FCR rate calculation (only FIRST_CALL counts as success) Technical: - Streamlit + Plotly for interactive visualizations - Light theme configuration (.streamlit/config.toml) - Fixed Plotly colorbar titlefont deprecation Documentation: - Updated PROJECT_CONTEXT.md, TODO.md, CHANGELOG.md - Added 4 new technical decisions (TD-014 to TD-017) - Created TROUBLESHOOTING.md with 10 common issues Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
53
tests/conftest.py
Normal file
53
tests/conftest.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
CXInsights - Pytest Configuration and Fixtures
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Set test environment
|
||||
os.environ["TESTING"] = "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def project_root() -> Path:
|
||||
"""Return the project root directory."""
|
||||
return Path(__file__).parent.parent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fixtures_dir(project_root: Path) -> Path:
|
||||
"""Return the fixtures directory."""
|
||||
return project_root / "tests" / "fixtures"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_audio_dir(fixtures_dir: Path) -> Path:
|
||||
"""Return the sample audio directory."""
|
||||
return fixtures_dir / "sample_audio"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_transcripts_dir(fixtures_dir: Path) -> Path:
|
||||
"""Return the sample transcripts directory."""
|
||||
return fixtures_dir / "sample_transcripts"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_dir(project_root: Path) -> Path:
|
||||
"""Return the config directory."""
|
||||
return project_root / "config"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def taxonomy_path(config_dir: Path) -> Path:
|
||||
"""Return the RCA taxonomy file path."""
|
||||
return config_dir / "rca_taxonomy.yaml"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings_path(config_dir: Path) -> Path:
|
||||
"""Return the settings file path."""
|
||||
return config_dir / "settings.yaml"
|
||||
0
tests/fixtures/expected_outputs/.gitkeep
vendored
Normal file
0
tests/fixtures/expected_outputs/.gitkeep
vendored
Normal file
0
tests/fixtures/sample_audio/.gitkeep
vendored
Normal file
0
tests/fixtures/sample_audio/.gitkeep
vendored
Normal file
0
tests/fixtures/sample_features/.gitkeep
vendored
Normal file
0
tests/fixtures/sample_features/.gitkeep
vendored
Normal file
0
tests/fixtures/sample_transcripts/compressed/.gitkeep
vendored
Normal file
0
tests/fixtures/sample_transcripts/compressed/.gitkeep
vendored
Normal file
0
tests/fixtures/sample_transcripts/raw/.gitkeep
vendored
Normal file
0
tests/fixtures/sample_transcripts/raw/.gitkeep
vendored
Normal file
0
tests/integration/__init__.py
Normal file
0
tests/integration/__init__.py
Normal file
0
tests/unit/__init__.py
Normal file
0
tests/unit/__init__.py
Normal file
582
tests/unit/test_aggregation.py
Normal file
582
tests/unit/test_aggregation.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""
|
||||
CXInsights - Aggregation Module Tests
|
||||
|
||||
Tests for statistics, severity scoring, and RCA tree building.
|
||||
v2.0: Updated with FCR, churn risk, and agent skill tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.aggregation import (
|
||||
AggregationConfig,
|
||||
BatchAggregation,
|
||||
DriverFrequency,
|
||||
DriverSeverity,
|
||||
ImpactLevel,
|
||||
RCANode,
|
||||
RCATree,
|
||||
RCATreeBuilder,
|
||||
SeverityCalculator,
|
||||
StatisticsCalculator,
|
||||
aggregate_batch,
|
||||
build_rca_tree,
|
||||
calculate_batch_statistics,
|
||||
calculate_driver_severities,
|
||||
)
|
||||
from src.models.call_analysis import (
|
||||
AgentClassification,
|
||||
AgentSkillIndicator,
|
||||
CallAnalysis,
|
||||
CallOutcome,
|
||||
ChurnRisk,
|
||||
EvidenceSpan,
|
||||
FCRStatus,
|
||||
ObservedFeatures,
|
||||
ProcessingStatus,
|
||||
RCALabel,
|
||||
Traceability,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def sample_analyses():
    """Create sample call analyses for testing (v2.0 with FCR, churn, agent).

    Returns five CallAnalysis objects covering the main combinations the
    aggregation code must handle:

    1. CALL001 - lost sale (price), first call, at risk
    2. CALL002 - lost sale (price + competitor), repeat call, at risk
    3. CALL003 - resolved inquiry with poor CX (long hold), good agent
    4. CALL004 - lost sale AND poor CX, repeat call, needs improvement
    5. CALL005 - clean completed sale, no drivers at all

    Expected aggregate facts relied on by the tests below:
    PRICE_TOO_HIGH in 3/5 calls, 2 repeat calls, 3 at-risk calls,
    2 good performers / 2 needs-improvement / 1 mixed agent.
    """
    # Shared sub-objects: the tests only exercise driver/outcome fields,
    # so a single minimal ObservedFeatures/Traceability is reused.
    base_observed = ObservedFeatures(
        audio_duration_sec=60.0,
        events=[],
    )
    base_traceability = Traceability(
        schema_version="1.0.0",
        prompt_version="v2.0",
        model_id="gpt-4o-mini",
    )

    analyses = []

    # Analysis 1: Lost sale due to price, first call, at risk
    analyses.append(CallAnalysis(
        call_id="CALL001",
        batch_id="test_batch",
        status=ProcessingStatus.SUCCESS,
        observed=base_observed,
        outcome=CallOutcome.SALE_LOST,
        lost_sales_drivers=[
            RCALabel(
                driver_code="PRICE_TOO_HIGH",
                confidence=0.9,
                evidence_spans=[EvidenceSpan(text="Es muy caro", start_time=10, end_time=12)],
            ),
        ],
        poor_cx_drivers=[],
        fcr_status=FCRStatus.FIRST_CALL,
        churn_risk=ChurnRisk.AT_RISK,
        churn_risk_drivers=[
            RCALabel(
                driver_code="COMPETITOR_MENTION",
                confidence=0.85,
                evidence_spans=[EvidenceSpan(text="Vodafone me ofrece", start_time=20, end_time=22)],
            ),
        ],
        agent_classification=AgentClassification.NEEDS_IMPROVEMENT,
        traceability=base_traceability,
    ))

    # Analysis 2: Lost sale due to price + competitor, repeat call
    analyses.append(CallAnalysis(
        call_id="CALL002",
        batch_id="test_batch",
        status=ProcessingStatus.SUCCESS,
        observed=base_observed,
        outcome=CallOutcome.SALE_LOST,
        lost_sales_drivers=[
            RCALabel(
                driver_code="PRICE_TOO_HIGH",
                confidence=0.85,
                evidence_spans=[EvidenceSpan(text="Muy caro", start_time=15, end_time=17)],
            ),
            RCALabel(
                driver_code="COMPETITOR_PREFERENCE",
                confidence=0.8,
                evidence_spans=[EvidenceSpan(text="La competencia ofrece mejor", start_time=20, end_time=23)],
            ),
        ],
        poor_cx_drivers=[],
        fcr_status=FCRStatus.REPEAT_CALL,
        # Repeat calls carry an FCR failure driver explaining the rellamada.
        fcr_failure_drivers=[
            RCALabel(
                driver_code="INCOMPLETE_RESOLUTION",
                confidence=0.8,
                evidence_spans=[EvidenceSpan(text="Ya llamé antes", start_time=5, end_time=7)],
            ),
        ],
        churn_risk=ChurnRisk.AT_RISK,
        agent_classification=AgentClassification.MIXED,
        traceability=base_traceability,
    ))

    # Analysis 3: Poor CX - long hold, first call, good agent
    analyses.append(CallAnalysis(
        call_id="CALL003",
        batch_id="test_batch",
        status=ProcessingStatus.SUCCESS,
        observed=base_observed,
        outcome=CallOutcome.INQUIRY_RESOLVED,
        lost_sales_drivers=[],
        poor_cx_drivers=[
            RCALabel(
                driver_code="LONG_HOLD",
                confidence=0.95,
                evidence_spans=[EvidenceSpan(text="Mucho tiempo esperando", start_time=5, end_time=8)],
            ),
        ],
        fcr_status=FCRStatus.FIRST_CALL,
        churn_risk=ChurnRisk.NO_RISK,
        agent_classification=AgentClassification.GOOD_PERFORMER,
        agent_positive_skills=[
            AgentSkillIndicator(
                skill_code="EMPATHY_SHOWN",
                skill_type="positive",
                confidence=0.9,
                evidence_spans=[EvidenceSpan(text="Entiendo su frustración", start_time=10, end_time=12)],
                description="Agent showed empathy",
            ),
        ],
        traceability=base_traceability,
    ))

    # Analysis 4: Both lost sale and poor CX, repeat call
    analyses.append(CallAnalysis(
        call_id="CALL004",
        batch_id="test_batch",
        status=ProcessingStatus.SUCCESS,
        observed=base_observed,
        outcome=CallOutcome.SALE_LOST,
        lost_sales_drivers=[
            RCALabel(
                driver_code="PRICE_TOO_HIGH",
                confidence=0.75,
                evidence_spans=[EvidenceSpan(text="No puedo pagar", start_time=30, end_time=32)],
            ),
        ],
        poor_cx_drivers=[
            RCALabel(
                driver_code="LOW_EMPATHY",
                confidence=0.7,
                evidence_spans=[EvidenceSpan(text="No me escucha", start_time=25, end_time=27)],
            ),
        ],
        fcr_status=FCRStatus.REPEAT_CALL,
        churn_risk=ChurnRisk.AT_RISK,
        agent_classification=AgentClassification.NEEDS_IMPROVEMENT,
        agent_improvement_areas=[
            AgentSkillIndicator(
                skill_code="POOR_CLOSING",
                skill_type="improvement_needed",
                confidence=0.8,
                evidence_spans=[EvidenceSpan(text="Bueno, pues llame otro día", start_time=50, end_time=53)],
                description="Agent failed to close",
            ),
        ],
        traceability=base_traceability,
    ))

    # Analysis 5: Successful sale (no issues), first call, good agent
    analyses.append(CallAnalysis(
        call_id="CALL005",
        batch_id="test_batch",
        status=ProcessingStatus.SUCCESS,
        observed=base_observed,
        outcome=CallOutcome.SALE_COMPLETED,
        lost_sales_drivers=[],
        poor_cx_drivers=[],
        fcr_status=FCRStatus.FIRST_CALL,
        churn_risk=ChurnRisk.NO_RISK,
        agent_classification=AgentClassification.GOOD_PERFORMER,
        traceability=base_traceability,
    ))

    return analyses
|
||||
|
||||
|
||||
class TestDriverFrequency:
    """Unit tests for the DriverFrequency statistics model."""

    def test_valid_frequency(self):
        """A fully-specified frequency record is accepted and keeps its values."""
        kwargs = dict(
            driver_code="PRICE_TOO_HIGH",
            category="lost_sales",
            total_occurrences=3,
            calls_affected=3,
            total_calls_in_batch=5,
            occurrence_rate=0.6,
            call_rate=0.6,
            avg_confidence=0.83,
            min_confidence=0.75,
            max_confidence=0.9,
        )
        frequency = DriverFrequency(**kwargs)

        assert frequency.driver_code == "PRICE_TOO_HIGH"
        assert frequency.occurrence_rate == 0.6

    def test_invalid_rate(self):
        """An occurrence rate outside [0, 1] is rejected with ValueError."""
        bad_kwargs = dict(
            driver_code="TEST",
            category="lost_sales",
            total_occurrences=1,
            calls_affected=1,
            total_calls_in_batch=5,
            occurrence_rate=1.5,  # deliberately out of range
            call_rate=0.2,
            avg_confidence=0.8,
            min_confidence=0.8,
            max_confidence=0.8,
        )
        with pytest.raises(ValueError):
            DriverFrequency(**bad_kwargs)
|
||||
|
||||
|
||||
class TestDriverSeverity:
    """Unit tests for the DriverSeverity model."""

    def test_valid_severity(self):
        """A well-formed severity record keeps its score and impact level."""
        severity = DriverSeverity(
            driver_code="PRICE_TOO_HIGH",
            category="lost_sales",
            base_severity=0.8,
            frequency_factor=0.6,
            confidence_factor=0.85,
            co_occurrence_factor=0.3,
            severity_score=65.0,
            impact_level=ImpactLevel.HIGH,
        )

        assert severity.severity_score == 65.0
        assert severity.impact_level == ImpactLevel.HIGH

    def test_invalid_severity_score(self):
        """A severity score outside the 0-100 scale is rejected."""
        bad_kwargs = dict(
            driver_code="TEST",
            category="lost_sales",
            base_severity=0.5,
            frequency_factor=0.5,
            confidence_factor=0.5,
            co_occurrence_factor=0.5,
            severity_score=150.0,  # deliberately out of range
            impact_level=ImpactLevel.HIGH,
        )
        with pytest.raises(ValueError):
            DriverSeverity(**bad_kwargs)
|
||||
|
||||
|
||||
class TestStatisticsCalculator:
    """Tests for StatisticsCalculator.

    All tests use the 5-call ``sample_analyses`` fixture; expected counts
    below are derived from that fixture's composition.
    """

    def test_calculate_frequencies(self, sample_analyses):
        """Test frequency calculation (v2.0 dict format)."""
        calculator = StatisticsCalculator()
        frequencies = calculator.calculate_frequencies(sample_analyses)

        # Check all categories are present
        assert "lost_sales" in frequencies
        assert "poor_cx" in frequencies
        assert "fcr_failure" in frequencies
        assert "churn_risk" in frequencies
        assert "agent_positive" in frequencies
        assert "agent_improvement" in frequencies

        # PRICE_TOO_HIGH appears in 3 calls (CALL001, CALL002, CALL004)
        lost_sales = frequencies["lost_sales"]
        price_freq = next(f for f in lost_sales if f.driver_code == "PRICE_TOO_HIGH")
        assert price_freq.total_occurrences == 3
        assert price_freq.calls_affected == 3
        assert price_freq.call_rate == 0.6  # 3/5 calls

        # FCR failure drivers
        fcr_failure = frequencies["fcr_failure"]
        assert len(fcr_failure) == 1  # INCOMPLETE_RESOLUTION

        # Agent positive skills
        agent_positive = frequencies["agent_positive"]
        assert len(agent_positive) == 1  # EMPATHY_SHOWN

    def test_calculate_outcome_rates(self, sample_analyses):
        """Test outcome rate calculation with v2.0 metrics."""
        calculator = StatisticsCalculator()
        rates = calculator.calculate_outcome_rates(sample_analyses)

        assert rates["total_calls"] == 5
        assert rates["lost_sales_count"] == 3  # Calls with lost sales drivers
        assert rates["poor_cx_count"] == 2  # Calls with poor CX drivers
        assert rates["both_count"] == 1  # Calls with both

        # v2.0: FCR metrics
        assert rates["fcr"]["first_call"] == 3
        assert rates["fcr"]["repeat_call"] == 2
        assert rates["fcr"]["repeat_rate"] == 0.4  # 2/5

        # v2.0: Churn metrics
        assert rates["churn"]["at_risk"] == 3
        assert rates["churn"]["no_risk"] == 2

        # v2.0: Agent metrics
        assert rates["agent"]["good_performer"] == 2
        assert rates["agent"]["needs_improvement"] == 2
        assert rates["agent"]["mixed"] == 1

    def test_empty_analyses(self):
        """Test with empty analyses list."""
        calculator = StatisticsCalculator()
        frequencies = calculator.calculate_frequencies([])

        # An empty batch still yields every category key, each empty.
        assert frequencies["lost_sales"] == []
        assert frequencies["poor_cx"] == []
        assert frequencies["fcr_failure"] == []
        assert frequencies["churn_risk"] == []

    def test_conditional_probabilities(self, sample_analyses):
        """Test conditional probability calculation."""
        # min_support=1 so even single co-occurrences count in this tiny batch.
        config = AggregationConfig(min_support=1)  # Low threshold for test
        calculator = StatisticsCalculator(config=config)
        probs = calculator.calculate_conditional_probabilities(sample_analyses)

        # Should find relationships between drivers
        # (e.g. PRICE_TOO_HIGH co-occurs with COMPETITOR_PREFERENCE in CALL002)
        assert len(probs) > 0
|
||||
|
||||
|
||||
class TestSeverityCalculator:
    """Tests for SeverityCalculator."""

    def test_get_base_severity(self):
        """Test base severity lookup."""
        calculator = SeverityCalculator()

        # From taxonomy — values assume the shipped rca_taxonomy defaults.
        assert calculator.get_base_severity("PRICE_TOO_HIGH", "lost_sales") == 0.8
        assert calculator.get_base_severity("RUDE_BEHAVIOR", "poor_cx") == 0.9

        # Unknown driver falls back to a neutral 0.5 default
        assert calculator.get_base_severity("UNKNOWN", "lost_sales") == 0.5

    def test_calculate_severity(self):
        """Test severity calculation."""
        calculator = SeverityCalculator()

        freq = DriverFrequency(
            driver_code="PRICE_TOO_HIGH",
            category="lost_sales",
            total_occurrences=3,
            calls_affected=3,
            total_calls_in_batch=5,
            occurrence_rate=0.6,
            call_rate=0.6,
            avg_confidence=0.85,
            min_confidence=0.75,
            max_confidence=0.9,
            commonly_co_occurs_with=["COMPETITOR_PREFERENCE"],
        )

        severity = calculator.calculate_severity(freq)

        assert severity.driver_code == "PRICE_TOO_HIGH"
        assert severity.base_severity == 0.8
        # Score is only range-checked here; the exact formula is the
        # calculator's concern, not this test's.
        assert 0 <= severity.severity_score <= 100
        assert severity.impact_level in [
            ImpactLevel.CRITICAL,
            ImpactLevel.HIGH,
            ImpactLevel.MEDIUM,
            ImpactLevel.LOW,
        ]

    def test_impact_level_thresholds(self):
        """Test impact level determination."""
        calculator = SeverityCalculator()

        # High severity + high frequency = CRITICAL
        high_freq = DriverFrequency(
            driver_code="TEST",
            category="lost_sales",
            total_occurrences=15,
            calls_affected=15,
            total_calls_in_batch=100,
            occurrence_rate=0.15,
            call_rate=0.15,  # >10%
            avg_confidence=0.9,
            min_confidence=0.9,
            max_confidence=0.9,
        )

        sev = calculator.calculate_severity(high_freq)
        # Should be HIGH or CRITICAL due to high frequency
        assert sev.impact_level in [ImpactLevel.CRITICAL, ImpactLevel.HIGH]
|
||||
|
||||
|
||||
class TestRCATreeBuilder:
    """Tests for RCATreeBuilder."""

    def test_build_tree(self, sample_analyses):
        """Test RCA tree building."""
        builder = RCATreeBuilder()
        tree = builder.build("test_batch", sample_analyses)

        assert tree.batch_id == "test_batch"
        assert tree.total_calls == 5
        # Fixture contains both lost-sales and poor-CX drivers, so both
        # root branches must be populated.
        assert len(tree.lost_sales_root) > 0
        assert len(tree.poor_cx_root) > 0

    def test_top_drivers(self, sample_analyses):
        """Test top drivers extraction."""
        builder = RCATreeBuilder()
        tree = builder.build("test_batch", sample_analyses)

        # PRICE_TOO_HIGH should be top driver (present in 3/5 calls)
        assert "PRICE_TOO_HIGH" in tree.top_lost_sales_drivers

    def test_tree_to_dict(self, sample_analyses):
        """Test tree serialization."""
        builder = RCATreeBuilder()
        tree = builder.build("test_batch", sample_analyses)

        tree_dict = tree.to_dict()

        assert "batch_id" in tree_dict
        assert "summary" in tree_dict
        assert "lost_sales_tree" in tree_dict
        assert "poor_cx_tree" in tree_dict

    def test_build_aggregation(self, sample_analyses):
        """Test full aggregation building."""
        builder = RCATreeBuilder()
        agg = builder.build_aggregation("test_batch", sample_analyses)

        assert isinstance(agg, BatchAggregation)
        assert agg.total_calls_processed == 5
        assert agg.successful_analyses == 5  # all fixture calls are SUCCESS
        assert agg.rca_tree is not None
|
||||
|
||||
|
||||
class TestConvenienceFunctions:
    """Tests for the module-level convenience wrappers."""

    def test_calculate_batch_statistics(self, sample_analyses):
        """calculate_batch_statistics returns both v1.0 and v2.0 key sets."""
        stats = calculate_batch_statistics(sample_analyses)

        expected_keys = (
            # v1.0 keys
            "outcome_rates",
            "lost_sales_frequencies",
            "poor_cx_frequencies",
            # v2.0 keys
            "fcr_failure_frequencies",
            "churn_risk_frequencies",
            "agent_positive_frequencies",
            "agent_improvement_frequencies",
        )
        for key in expected_keys:
            assert key in stats

        # v2.0 outcome_rates should have nested dicts
        for nested_key in ("fcr", "churn", "agent"):
            assert nested_key in stats["outcome_rates"]

    def test_build_rca_tree_function(self, sample_analyses):
        """build_rca_tree returns an RCATree tagged with the batch id."""
        tree = build_rca_tree("test_batch", sample_analyses)

        assert isinstance(tree, RCATree)
        assert tree.batch_id == "test_batch"

    def test_aggregate_batch_function(self, sample_analyses):
        """aggregate_batch returns a BatchAggregation tagged with the batch id."""
        aggregation = aggregate_batch("test_batch", sample_analyses)

        assert isinstance(aggregation, BatchAggregation)
        assert aggregation.batch_id == "test_batch"
|
||||
|
||||
|
||||
class TestRCANode:
    """Tests for RCANode model."""

    def test_node_to_dict(self):
        """Test node serialization."""
        # A node wraps one frequency record plus its derived severity.
        freq = DriverFrequency(
            driver_code="PRICE_TOO_HIGH",
            category="lost_sales",
            total_occurrences=3,
            calls_affected=3,
            total_calls_in_batch=5,
            occurrence_rate=0.6,
            call_rate=0.6,
            avg_confidence=0.85,
            min_confidence=0.75,
            max_confidence=0.9,
        )

        sev = DriverSeverity(
            driver_code="PRICE_TOO_HIGH",
            category="lost_sales",
            base_severity=0.8,
            frequency_factor=0.6,
            confidence_factor=0.85,
            co_occurrence_factor=0.3,
            severity_score=65.0,
            impact_level=ImpactLevel.HIGH,
        )

        node = RCANode(
            driver_code="PRICE_TOO_HIGH",
            category="lost_sales",
            frequency=freq,
            severity=sev,
            priority_rank=1,
            sample_evidence=["Es muy caro para mí"],
        )

        node_dict = node.to_dict()

        assert node_dict["driver_code"] == "PRICE_TOO_HIGH"
        assert node_dict["priority_rank"] == 1
        # Nested models must be serialized under their own keys.
        assert "frequency" in node_dict
        assert "severity" in node_dict
|
||||
|
||||
|
||||
class TestEmergentPatterns:
    """Tests for emergent pattern extraction."""

    def test_extract_emergent(self):
        """Test emergent pattern extraction."""
        base_observed = ObservedFeatures(audio_duration_sec=60.0, events=[])
        base_trace = Traceability(
            schema_version="1.0.0",
            prompt_version="v1.0",
            model_id="gpt-4o-mini",
        )

        # Single call whose driver uses the OTHER_EMERGENT escape hatch
        # with a proposed replacement label.
        analyses = [
            CallAnalysis(
                call_id="EMG001",
                batch_id="test",
                status=ProcessingStatus.SUCCESS,
                observed=base_observed,
                outcome=CallOutcome.SALE_LOST,
                lost_sales_drivers=[
                    RCALabel(
                        driver_code="OTHER_EMERGENT",
                        confidence=0.7,
                        evidence_spans=[
                            EvidenceSpan(text="Nuevo patrón", start_time=0, end_time=1)
                        ],
                        proposed_label="NEW_PATTERN",
                    )
                ],
                poor_cx_drivers=[],
                traceability=base_trace,
            )
        ]

        calculator = StatisticsCalculator()
        emergent = calculator.extract_emergent_patterns(analyses)

        assert len(emergent) == 1
        assert emergent[0]["proposed_label"] == "NEW_PATTERN"
        assert emergent[0]["occurrences"] == 1
|
||||
480
tests/unit/test_compression.py
Normal file
480
tests/unit/test_compression.py
Normal file
@@ -0,0 +1,480 @@
|
||||
"""
|
||||
CXInsights - Compression Module Tests
|
||||
|
||||
Tests for transcript compression and semantic extraction.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.compression.compressor import (
|
||||
TranscriptCompressor,
|
||||
compress_for_prompt,
|
||||
compress_transcript,
|
||||
)
|
||||
from src.compression.models import (
|
||||
AgentOffer,
|
||||
CompressionConfig,
|
||||
CompressedTranscript,
|
||||
CustomerIntent,
|
||||
CustomerObjection,
|
||||
IntentType,
|
||||
KeyMoment,
|
||||
ObjectionType,
|
||||
ResolutionStatement,
|
||||
ResolutionType,
|
||||
)
|
||||
from src.transcription.models import SpeakerTurn, Transcript, TranscriptMetadata
|
||||
|
||||
|
||||
class TestCustomerIntent:
    """Tests for CustomerIntent model."""

    def test_to_prompt_text(self):
        """Both the intent type and the verbatim quote surface in prompt text."""
        cancel_intent = CustomerIntent(
            intent_type=IntentType.CANCEL,
            description="Customer wants to cancel service",
            confidence=0.9,
            verbatim_quotes=["quiero cancelar mi servicio"],
        )

        rendered = cancel_intent.to_prompt_text()

        assert "CANCEL" in rendered
        assert "quiero cancelar" in rendered

    def test_to_prompt_text_no_quotes(self):
        """Without quotes the evidence section is omitted from prompt text."""
        inquiry_intent = CustomerIntent(
            intent_type=IntentType.INQUIRY,
            description="Customer asking about prices",
            confidence=0.8,
        )

        rendered = inquiry_intent.to_prompt_text()

        assert "INQUIRY" in rendered
        assert "Evidence:" not in rendered
|
||||
|
||||
|
||||
class TestCustomerObjection:
    """Tests for CustomerObjection model."""

    def test_addressed_status(self):
        """The addressed flag toggles the status tag in prompt text."""

        def make_objection(addressed):
            # Identical objection except for the addressed flag.
            return CustomerObjection(
                objection_type=ObjectionType.PRICE,
                description="Too expensive",
                turn_index=5,
                verbatim="Es muy caro",
                addressed=addressed,
            )

        assert "[ADDRESSED]" in make_objection(True).to_prompt_text()
        assert "[UNADDRESSED]" in make_objection(False).to_prompt_text()
|
||||
|
||||
|
||||
class TestAgentOffer:
    """Tests for AgentOffer model."""

    def test_acceptance_status(self):
        """The accepted flag drives the status tag in prompt text."""

        def make_offer(accepted):
            # Identical offer except for the acceptance flag.
            return AgentOffer(
                offer_type="discount",
                description="10% discount",
                turn_index=10,
                verbatim="Le ofrezco un 10% de descuento",
                accepted=accepted,
            )

        assert "[ACCEPTED]" in make_offer(True).to_prompt_text()
        assert "[REJECTED]" in make_offer(False).to_prompt_text()

        # accepted=None means undecided: neither tag should appear.
        pending_text = make_offer(None).to_prompt_text()
        assert "[ACCEPTED]" not in pending_text
        assert "[REJECTED]" not in pending_text
|
||||
|
||||
|
||||
class TestCompressedTranscript:
    """Tests for CompressedTranscript model.

    Fixes over the original version:
    - ``test_to_prompt_text_empty`` asserted ``len(text) >= 0``, which is
      true for every string and therefore tested nothing; it now pins the
      actual contract (returns a string without raising).
    - ``test_get_stats`` built ``CustomerIntent``/``CustomerObjection``
      with positional arguments, inconsistent with the keyword style used
      everywhere else in the suite (and rejected outright if the models
      are keyword-only, e.g. pydantic — TODO confirm model base class).
    """

    def test_to_prompt_text_basic(self):
        """Intent and objection sections both appear in generated prompt text."""
        compressed = CompressedTranscript(
            call_id="TEST001",
            customer_intents=[
                CustomerIntent(
                    intent_type=IntentType.CANCEL,
                    description="Wants to cancel",
                    confidence=0.9,
                )
            ],
            objections=[
                CustomerObjection(
                    objection_type=ObjectionType.PRICE,
                    description="Too expensive",
                    turn_index=5,
                    verbatim="Es caro",
                )
            ],
        )

        text = compressed.to_prompt_text()

        assert "CUSTOMER INTENT" in text
        assert "CUSTOMER OBJECTIONS" in text
        assert "CANCEL" in text
        assert "price" in text.lower()

    def test_to_prompt_text_empty(self):
        """An element-free transcript renders without raising."""
        compressed = CompressedTranscript(call_id="EMPTY001")

        text = compressed.to_prompt_text()

        # Empty output is legitimate for an empty transcript; the contract
        # is simply "returns a string, does not raise".
        assert isinstance(text, str)

    def test_to_prompt_text_truncation(self):
        """Oversized prompt text is cut to max_chars and flagged as truncated."""
        compressed = CompressedTranscript(
            call_id="LONG001",
            key_moments=[
                KeyMoment(
                    moment_type="test",
                    description="x" * 500,
                    turn_index=i,
                    start_time=float(i),
                    verbatim="y" * 200,
                    speaker="customer",
                )
                for i in range(50)
            ],
        )

        text = compressed.to_prompt_text(max_chars=1000)

        assert len(text) <= 1000
        assert "truncated" in text

    def test_get_stats(self):
        """get_stats reports sizes, ratio, and per-category extraction counts."""
        compressed = CompressedTranscript(
            call_id="STATS001",
            original_turn_count=50,
            original_char_count=10000,
            compressed_char_count=2000,
            compression_ratio=0.8,
            customer_intents=[
                CustomerIntent(
                    intent_type=IntentType.CANCEL,
                    description="test",
                    confidence=0.9,
                )
            ],
            objections=[
                CustomerObjection(
                    objection_type=ObjectionType.PRICE,
                    description="test",
                    turn_index=0,
                    verbatim="test",
                )
            ],
        )

        stats = compressed.get_stats()

        assert stats["original_turns"] == 50
        assert stats["original_chars"] == 10000
        assert stats["compressed_chars"] == 2000
        assert stats["compression_ratio"] == 0.8
        assert stats["intents_extracted"] == 1
        assert stats["objections_extracted"] == 1
|
||||
|
||||
|
||||
class TestTranscriptCompressor:
|
||||
"""Tests for TranscriptCompressor."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_transcript(self):
|
||||
"""Create a sample transcript for testing."""
|
||||
return Transcript(
|
||||
call_id="COMP001",
|
||||
turns=[
|
||||
SpeakerTurn(
|
||||
speaker="agent",
|
||||
text="Hola, buenos días, gracias por llamar.",
|
||||
start_time=0.0,
|
||||
end_time=2.0,
|
||||
),
|
||||
SpeakerTurn(
|
||||
speaker="customer",
|
||||
text="Hola, quiero cancelar mi servicio porque es muy caro.",
|
||||
start_time=2.5,
|
||||
end_time=5.0,
|
||||
),
|
||||
SpeakerTurn(
|
||||
speaker="agent",
|
||||
text="Entiendo. Le puedo ofrecer un 20% de descuento.",
|
||||
start_time=5.5,
|
||||
end_time=8.0,
|
||||
),
|
||||
SpeakerTurn(
|
||||
speaker="customer",
|
||||
text="No gracias, ya tomé la decisión.",
|
||||
start_time=8.5,
|
||||
end_time=10.0,
|
||||
),
|
||||
SpeakerTurn(
|
||||
speaker="agent",
|
||||
text="Entiendo, si cambia de opinión estamos para ayudarle.",
|
||||
start_time=10.5,
|
||||
end_time=13.0,
|
||||
),
|
||||
],
|
||||
metadata=TranscriptMetadata(
|
||||
audio_duration_sec=60.0,
|
||||
language="es",
|
||||
),
|
||||
)
|
||||
|
||||
def test_compress_extracts_intent(self, sample_transcript):
|
||||
"""Test that cancel intent is extracted."""
|
||||
compressor = TranscriptCompressor()
|
||||
compressed = compressor.compress(sample_transcript)
|
||||
|
||||
assert len(compressed.customer_intents) > 0
|
||||
assert any(
|
||||
i.intent_type == IntentType.CANCEL
|
||||
for i in compressed.customer_intents
|
||||
)
|
||||
|
||||
def test_compress_extracts_price_objection(self, sample_transcript):
|
||||
"""Test that price objection is extracted."""
|
||||
compressor = TranscriptCompressor()
|
||||
compressed = compressor.compress(sample_transcript)
|
||||
|
||||
assert len(compressed.objections) > 0
|
||||
assert any(
|
||||
o.objection_type == ObjectionType.PRICE
|
||||
for o in compressed.objections
|
||||
)
|
||||
|
||||
def test_compress_extracts_offer(self, sample_transcript):
|
||||
"""Test that agent offer is extracted."""
|
||||
compressor = TranscriptCompressor()
|
||||
compressed = compressor.compress(sample_transcript)
|
||||
|
||||
assert len(compressed.agent_offers) > 0
|
||||
|
||||
def test_compress_extracts_key_moments(self, sample_transcript):
|
||||
"""Test that key moments are extracted."""
|
||||
compressor = TranscriptCompressor()
|
||||
compressed = compressor.compress(sample_transcript)
|
||||
|
||||
# Should find rejection and firm_decision
|
||||
moment_types = [m.moment_type for m in compressed.key_moments]
|
||||
assert len(moment_types) > 0
|
||||
|
||||
def test_compression_ratio(self, sample_transcript):
|
||||
"""Test that compression ratio is calculated."""
|
||||
compressor = TranscriptCompressor()
|
||||
compressed = compressor.compress(sample_transcript)
|
||||
|
||||
assert compressed.compression_ratio > 0
|
||||
assert compressed.original_char_count > compressed.compressed_char_count
|
||||
|
||||
def test_compression_respects_max_limits(self, sample_transcript):
|
||||
"""Test that max limits are respected."""
|
||||
config = CompressionConfig(
|
||||
max_intents=1,
|
||||
max_offers=1,
|
||||
max_objections=1,
|
||||
max_key_moments=2,
|
||||
)
|
||||
compressor = TranscriptCompressor(config=config)
|
||||
compressed = compressor.compress(sample_transcript)
|
||||
|
||||
assert len(compressed.customer_intents) <= 1
|
||||
assert len(compressed.agent_offers) <= 1
|
||||
assert len(compressed.objections) <= 1
|
||||
assert len(compressed.key_moments) <= 2
|
||||
|
||||
def test_generates_summary(self, sample_transcript):
    """A non-empty summary mentioning the cancellation is produced."""
    result = TranscriptCompressor().compress(sample_transcript)

    assert len(result.call_summary) > 0
    assert "cancel" in result.call_summary.lower()
class TestIntentExtraction:
    """Pattern-level tests for customer intent detection."""

    def make_transcript(self, customer_text: str) -> Transcript:
        """Build a minimal two-turn transcript around *customer_text*."""
        agent_turn = SpeakerTurn(
            speaker="agent", text="Hola", start_time=0, end_time=1
        )
        customer_turn = SpeakerTurn(
            speaker="customer", text=customer_text, start_time=1, end_time=3
        )
        return Transcript(call_id="INT001", turns=[agent_turn, customer_turn])

    def _assert_intent_detected(self, phrases, expected_type):
        """Compress each phrase and require an intent of *expected_type*."""
        compressor = TranscriptCompressor()
        for pattern in phrases:
            compressed = compressor.compress(self.make_transcript(pattern))
            detected = {i.intent_type for i in compressed.customer_intents}
            assert expected_type in detected, f"Failed for: {pattern}"

    def test_cancel_intent_patterns(self):
        """Several cancellation phrasings must map to CANCEL."""
        self._assert_intent_detected(
            [
                "Quiero cancelar mi servicio",
                "Quiero dar de baja mi cuenta",
                "No quiero continuar con el servicio",
            ],
            IntentType.CANCEL,
        )

    def test_purchase_intent_patterns(self):
        """Purchase phrasings must map to PURCHASE."""
        self._assert_intent_detected(
            [
                "Quiero contratar el plan premium",
                "Me interesa comprar el servicio",
            ],
            IntentType.PURCHASE,
        )

    def test_complaint_intent_patterns(self):
        """Complaint phrasings must map to COMPLAINT."""
        self._assert_intent_detected(
            [
                "Tengo un problema con mi factura",
                "Estoy muy molesto con el servicio",
                "Quiero poner una queja",
            ],
            IntentType.COMPLAINT,
        )
class TestObjectionExtraction:
    """Pattern-level tests for objection detection."""

    def make_transcript(self, customer_text: str) -> Transcript:
        """Build a minimal two-turn transcript around *customer_text*."""
        return Transcript(
            call_id="OBJ001",
            turns=[
                SpeakerTurn(
                    speaker="agent", text="Hola", start_time=0, end_time=1
                ),
                SpeakerTurn(
                    speaker="customer",
                    text=customer_text,
                    start_time=1,
                    end_time=3,
                ),
            ],
        )

    def _assert_objection_detected(self, phrases, expected_type):
        """Compress each phrase and require an objection of *expected_type*."""
        compressor = TranscriptCompressor()
        for pattern in phrases:
            compressed = compressor.compress(self.make_transcript(pattern))
            found = {o.objection_type for o in compressed.objections}
            assert expected_type in found, f"Failed for: {pattern}"

    def test_price_objection_patterns(self):
        """Cost-related phrasings must map to PRICE."""
        self._assert_objection_detected(
            [
                "Es muy caro para mí",
                "Es demasiado costoso",
                "No tengo el dinero ahora",
                "Está fuera de mi presupuesto",
            ],
            ObjectionType.PRICE,
        )

    def test_timing_objection_patterns(self):
        """Delay/deferral phrasings must map to TIMING."""
        self._assert_objection_detected(
            [
                "No es buen momento",
                "Déjame pensarlo",
                "Lo voy a pensar",
            ],
            ObjectionType.TIMING,
        )
class TestConvenienceFunctions:
    """Tests for the module-level convenience wrappers."""

    @pytest.fixture
    def sample_transcript(self):
        """One customer turn combining a cancel intent and a price objection."""
        agent = SpeakerTurn(
            speaker="agent", text="Hola", start_time=0, end_time=1
        )
        customer = SpeakerTurn(
            speaker="customer",
            text="Quiero cancelar, es muy caro",
            start_time=1,
            end_time=3,
        )
        return Transcript(call_id="CONV001", turns=[agent, customer])

    def test_compress_transcript(self, sample_transcript):
        """compress_transcript returns a CompressedTranscript for the call."""
        compressed = compress_transcript(sample_transcript)

        assert isinstance(compressed, CompressedTranscript)
        assert compressed.call_id == "CONV001"

    def test_compress_for_prompt(self, sample_transcript):
        """compress_for_prompt yields a non-empty string."""
        rendered = compress_for_prompt(sample_transcript)

        assert isinstance(rendered, str)
        assert len(rendered) > 0

    def test_compress_for_prompt_max_chars(self, sample_transcript):
        """max_chars is a hard upper bound on the rendered length."""
        rendered = compress_for_prompt(sample_transcript, max_chars=100)

        assert len(rendered) <= 100
394
tests/unit/test_features.py
Normal file
394
tests/unit/test_features.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""
|
||||
CXInsights - Feature Extraction Tests
|
||||
|
||||
Tests for deterministic feature extraction.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.features.event_detector import EventDetector, EventDetectorConfig, detect_events
|
||||
from src.features.extractor import FeatureExtractor, extract_features
|
||||
from src.features.turn_metrics import TurnMetricsCalculator, calculate_turn_metrics
|
||||
from src.models.call_analysis import EventType
|
||||
from src.transcription.models import SpeakerTurn, Transcript, TranscriptMetadata
|
||||
|
||||
|
||||
@pytest.fixture
def sample_transcript():
    """A retention-style call: cancel request, hold, offer, rejection, transfer."""
    turns = [
        SpeakerTurn(
            speaker="agent",
            text="Buenos días, ¿en qué puedo ayudarle?",
            start_time=0.0,
            end_time=3.0,
        ),
        SpeakerTurn(
            speaker="customer",
            text="Hola, quiero cancelar mi servicio.",
            start_time=3.5,
            end_time=6.5,
        ),
        SpeakerTurn(
            speaker="agent",
            text="Entiendo. Un momento, por favor, le pongo en espera mientras consulto.",
            start_time=7.0,
            end_time=12.0,
        ),
        # The long gap before the next turn models the hold period.
        SpeakerTurn(
            speaker="agent",
            text="Gracias por la espera. Le cuento que tenemos una oferta especial.",
            start_time=45.0,
            end_time=52.0,
        ),
        SpeakerTurn(
            speaker="customer",
            text="No me interesa, es demasiado caro.",
            start_time=52.5,
            end_time=56.0,
        ),
        SpeakerTurn(
            speaker="agent",
            text="Le voy a transferir con el departamento de retenciones.",
            start_time=56.5,
            end_time=61.0,
        ),
    ]
    meta = TranscriptMetadata(
        audio_duration_sec=120.0,
        audio_file="TEST001.mp3",
        provider="test",
        speaker_count=2,
    )
    return Transcript(call_id="TEST001", turns=turns, metadata=meta)
@pytest.fixture
def transcript_with_interruptions():
    """Three turns where each next speaker starts before the previous ends."""
    overlapping_turns = [
        SpeakerTurn(
            speaker="agent",
            text="Le explico cómo funciona el proceso...",
            start_time=0.0,
            end_time=5.0,
        ),
        SpeakerTurn(
            speaker="customer",
            text="Pero es que yo ya lo sé...",
            start_time=4.5,  # overlaps the agent's turn
            end_time=7.0,
        ),
        SpeakerTurn(
            speaker="agent",
            text="Perdone, le decía que...",
            start_time=6.8,  # overlaps the customer's turn
            end_time=10.0,
        ),
    ]
    return Transcript(
        call_id="TEST002",
        turns=overlapping_turns,
        metadata=TranscriptMetadata(
            audio_duration_sec=60.0,
            audio_file="TEST002.mp3",
            provider="test",
        ),
    )
@pytest.fixture
def transcript_with_silences():
    """Three turns separated by 10-second and 8-second gaps."""
    gapped_turns = [
        SpeakerTurn(
            speaker="agent",
            text="Voy a comprobar su cuenta.",
            start_time=0.0,
            end_time=3.0,
        ),
        # 10 second gap
        SpeakerTurn(
            speaker="agent",
            text="Ya tengo la información.",
            start_time=13.0,
            end_time=16.0,
        ),
        # 8 second gap
        SpeakerTurn(
            speaker="customer",
            text="¿Y qué dice?",
            start_time=24.0,
            end_time=26.0,
        ),
    ]
    return Transcript(
        call_id="TEST003",
        turns=gapped_turns,
        metadata=TranscriptMetadata(
            audio_duration_sec=30.0,
            audio_file="TEST003.mp3",
            provider="test",
        ),
    )
class TestEventDetector:
    """Tests for EventDetector pattern and gap detection."""

    @staticmethod
    def _of_type(events, event_type):
        """Filter *events* down to those matching *event_type*."""
        return [e for e in events if e.event_type == event_type]

    def test_detect_hold_start(self, sample_transcript):
        """The 'le pongo en espera' phrasing registers a HOLD_START."""
        holds = self._of_type(
            detect_events(sample_transcript), EventType.HOLD_START
        )
        assert len(holds) >= 1

    def test_detect_hold_end(self, sample_transcript):
        """The 'Gracias por la espera' phrasing registers a HOLD_END."""
        hold_ends = self._of_type(
            detect_events(sample_transcript), EventType.HOLD_END
        )
        assert len(hold_ends) >= 1

    def test_detect_transfer(self, sample_transcript):
        """The 'Le voy a transferir' phrasing registers a TRANSFER."""
        transfers = self._of_type(
            detect_events(sample_transcript), EventType.TRANSFER
        )
        assert len(transfers) >= 1

    def test_detect_silence(self, transcript_with_silences):
        """Gaps above the configured threshold become SILENCE events."""
        detector = EventDetector(EventDetectorConfig(silence_threshold_sec=5.0))
        silences = self._of_type(
            detector.detect_all(transcript_with_silences), EventType.SILENCE
        )

        assert len(silences) == 2  # exactly the two gaps longer than 5 s
        assert silences[0].duration_sec == 10.0
        assert silences[1].duration_sec == 8.0

    def test_detect_interruptions(self, transcript_with_interruptions):
        """Each overlapping turn pair produces one INTERRUPTION event."""
        interruptions = self._of_type(
            detect_events(transcript_with_interruptions), EventType.INTERRUPTION
        )
        assert len(interruptions) == 2

    def test_events_sorted_by_time(self, sample_transcript):
        """Detected events are returned in non-decreasing start_time order."""
        starts = [e.start_time for e in detect_events(sample_transcript)]
        assert starts == sorted(starts)

    def test_event_has_observed_source(self, sample_transcript):
        """Every detected event is tagged source='observed'."""
        assert all(
            e.source == "observed" for e in detect_events(sample_transcript)
        )
class TestTurnMetrics:
    """Tests for TurnMetricsCalculator aggregates."""

    def test_turn_counts(self, sample_transcript):
        """Total/agent/customer turn counts match the fixture."""
        metrics = calculate_turn_metrics(sample_transcript)

        counts = (metrics.total_turns, metrics.agent_turns, metrics.customer_turns)
        assert counts == (6, 4, 2)

    def test_talk_ratios(self, sample_transcript):
        """Ratios are valid fractions and their sum stays near one."""
        metrics = calculate_turn_metrics(sample_transcript)

        ratios = (
            metrics.agent_talk_ratio,
            metrics.customer_talk_ratio,
            metrics.silence_ratio,
        )
        for ratio in ratios:
            assert 0 <= ratio <= 1

        # Sum should be approximately 1 (may have gaps); allow rounding slack.
        assert sum(ratios) <= 1.1

    def test_interruption_count(self, transcript_with_interruptions):
        """Both overlapping pairs are counted as interruptions."""
        metrics = calculate_turn_metrics(transcript_with_interruptions)
        assert metrics.interruption_count == 2

    def test_avg_turn_duration(self, sample_transcript):
        """Average turn duration is strictly positive for a non-empty call."""
        metrics = calculate_turn_metrics(sample_transcript)
        assert metrics.avg_turn_duration_sec > 0

    def test_metrics_has_observed_source(self, sample_transcript):
        """Metrics carry source='observed'."""
        assert calculate_turn_metrics(sample_transcript).source == "observed"

    def test_empty_transcript(self):
        """A transcript with no turns yields all-zero counts, no errors."""
        empty = Transcript(
            call_id="EMPTY",
            turns=[],
            metadata=TranscriptMetadata(
                audio_duration_sec=0.0,
                audio_file="empty.mp3",
                provider="test",
            ),
        )

        metrics = calculate_turn_metrics(empty)

        assert metrics.total_turns == 0
        assert metrics.agent_turns == 0
        assert metrics.customer_turns == 0
class TestFeatureExtractor:
    """Tests for the end-to-end FeatureExtractor."""

    def test_extract_features(self, sample_transcript):
        """Call id, duration, and language are carried into the features."""
        features = extract_features(sample_transcript)

        assert features.call_id == "TEST001"
        assert features.audio_duration_sec == 120.0
        assert features.language == "es"

    def test_features_have_events(self, sample_transcript):
        """Detected events are attached to the feature set."""
        assert len(extract_features(sample_transcript).events) > 0

    def test_features_have_metrics(self, sample_transcript):
        """Turn metrics are attached and reflect the fixture's six turns."""
        features = extract_features(sample_transcript)

        assert features.turn_metrics is not None
        assert features.turn_metrics.total_turns == 6

    def test_hold_aggregation(self, sample_transcript):
        """At least one hold is aggregated from the fixture."""
        assert extract_features(sample_transcript).hold_count >= 1

    def test_transfer_aggregation(self, sample_transcript):
        """At least one transfer is aggregated from the fixture."""
        assert extract_features(sample_transcript).transfer_count >= 1

    def test_silence_aggregation(self, transcript_with_silences):
        """Both significant gaps are aggregated as silences."""
        assert extract_features(transcript_with_silences).silence_count == 2

    def test_interruption_aggregation(self, transcript_with_interruptions):
        """Both overlaps are aggregated as interruptions."""
        features = extract_features(transcript_with_interruptions)
        assert features.interruption_count == 2

    def test_deterministic_output(self, sample_transcript):
        """Two runs over the same transcript agree on every aggregate."""
        first = extract_features(sample_transcript)
        second = extract_features(sample_transcript)

        assert first.hold_count == second.hold_count
        assert first.transfer_count == second.transfer_count
        assert first.silence_count == second.silence_count
        assert len(first.events) == len(second.events)
class TestSpanishPatterns:
    """Tests for Spanish-language pattern coverage in event detection."""

    @staticmethod
    def _single_agent_turn(text: str) -> Transcript:
        """Wrap one agent utterance in a minimal transcript."""
        return Transcript(
            call_id="TEST",
            turns=[
                SpeakerTurn(
                    speaker="agent",
                    text=text,
                    start_time=0.0,
                    end_time=3.0,
                ),
            ],
            metadata=TranscriptMetadata(
                audio_duration_sec=10.0,
                audio_file="test.mp3",
                provider="test",
            ),
        )

    def _check_patterns(self, cases, event_type):
        """Assert presence/absence of *event_type* for each (text, expected) pair."""
        for text, should_match in cases:
            events = detect_events(self._single_agent_turn(text))
            hits = [e for e in events if e.event_type == event_type]

            if should_match:
                assert len(hits) >= 1, f"Should match: {text}"
            else:
                assert len(hits) == 0, f"Should not match: {text}"

    def test_hold_patterns_spanish(self):
        """Hold phrasings trigger HOLD_START; greetings do not."""
        self._check_patterns(
            [
                ("Un momento, por favor", True),
                ("Le voy a poner en espera", True),
                ("Espere un segundo", True),
                ("No cuelgue", True),
                ("Déjeme verificar", True),
                ("Buenos días", False),
                ("Gracias por llamar", False),
            ],
            EventType.HOLD_START,
        )

    def test_transfer_patterns_spanish(self):
        """Transfer phrasings trigger TRANSFER; pleasantries do not."""
        self._check_patterns(
            [
                ("Le voy a transferir con el departamento de ventas", True),
                ("Le paso con mi compañero", True),
                ("Le comunico con facturación", True),
                ("Va a ser transferido", True),
                ("Gracias por su paciencia", False),
            ],
            EventType.TRANSFER,
        )
393
tests/unit/test_inference.py
Normal file
393
tests/unit/test_inference.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
CXInsights - Inference Module Tests
|
||||
|
||||
Tests for LLM client, prompt manager, and analyzer.
|
||||
Uses mocks for LLM calls to avoid API costs.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.inference.client import LLMClient, LLMClientConfig, LLMResponse
|
||||
from src.inference.prompt_manager import (
|
||||
PromptManager,
|
||||
PromptTemplate,
|
||||
format_events_for_prompt,
|
||||
format_transcript_for_prompt,
|
||||
)
|
||||
from src.models.call_analysis import Event, EventType
|
||||
from src.transcription.models import SpeakerTurn, Transcript, TranscriptMetadata
|
||||
|
||||
|
||||
class TestLLMResponse:
    """Tests for the LLMResponse value object."""

    def test_cost_estimate(self):
        """Cost follows GPT-4o-mini pricing: $0.15/1M input, $0.60/1M output."""
        response = LLMResponse(
            content="test",
            prompt_tokens=1000,
            completion_tokens=500,
            total_tokens=1500,
        )

        expected = 1000 / 1_000_000 * 0.15 + 500 / 1_000_000 * 0.60
        assert abs(response.cost_estimate_usd - expected) < 0.0001

    def test_success_flag(self):
        """The success flag mirrors what was passed at construction."""
        ok = LLMResponse(content="test", success=True)
        failed = LLMResponse(content="", success=False, error="API error")

        assert ok.success is True
        assert failed.success is False

    def test_parsed_json(self):
        """parsed_json stores the pre-parsed payload as-is."""
        response = LLMResponse(
            content='{"key": "value"}',
            parsed_json={"key": "value"},
        )

        assert response.parsed_json == {"key": "value"}
class TestLLMClientConfig:
    """Tests for LLMClientConfig defaults and overrides."""

    def test_default_config(self):
        """Defaults: gpt-4o-mini, temperature 0.1, 4000 tokens, JSON mode on."""
        config = LLMClientConfig()

        assert config.model == "gpt-4o-mini"
        assert config.temperature == 0.1
        assert config.max_tokens == 4000
        assert config.json_mode is True

    def test_custom_config(self):
        """Explicit keyword arguments override the defaults."""
        config = LLMClientConfig(
            model="gpt-4o",
            temperature=0.5,
            max_tokens=8000,
        )

        assert config.model == "gpt-4o"
        assert config.temperature == 0.5
        assert config.max_tokens == 8000
class TestLLMClient:
    """Tests for LLMClient (no real API calls are made)."""

    @staticmethod
    def _client():
        """Build a client with a fake API key injected via the environment."""
        with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
            return LLMClient()

    def test_requires_api_key(self):
        """Construction without any API key must raise ValueError."""
        with patch.dict("os.environ", {}, clear=True):
            with pytest.raises(ValueError, match="API key required"):
                LLMClient()

    def test_parse_json_valid(self):
        """Plain JSON parses to the expected dict."""
        assert self._client()._parse_json('{"key": "value"}') == {"key": "value"}

    def test_parse_json_with_markdown(self):
        """Markdown ```json fences are handled before parsing."""
        content = '```json\n{"key": "value"}\n```'
        assert self._client()._parse_json(content) == {"key": "value"}

    def test_parse_json_extract_from_text(self):
        """A JSON object embedded in surrounding prose is still recovered."""
        content = 'Here is the result: {"key": "value"} end.'
        assert self._client()._parse_json(content) == {"key": "value"}

    def test_parse_json_invalid(self):
        """Unparseable content yields None rather than raising."""
        assert self._client()._parse_json("not json at all") is None

    def test_usage_stats_tracking(self):
        """A fresh client starts with zeroed usage counters."""
        stats = self._client().get_usage_stats()

        assert stats["total_calls"] == 0
        assert stats["total_tokens"] == 0

    def test_reset_usage_stats(self):
        """reset_usage_stats zeroes previously accumulated counters."""
        client = self._client()
        client._total_calls = 10
        client._total_tokens = 5000

        client.reset_usage_stats()

        stats = client.get_usage_stats()
        assert stats["total_calls"] == 0
        assert stats["total_tokens"] == 0
class TestPromptTemplate:
    """Tests for PromptTemplate rendering and message generation."""

    @staticmethod
    def _template(system: str, user: str) -> PromptTemplate:
        """Build a template with fixed name/version and the given parts."""
        return PromptTemplate(
            name="test",
            version="v1.0",
            system=system,
            user=user,
        )

    def test_render_basic(self):
        """$-placeholders are substituted in both system and user parts."""
        template = self._template(
            "You are analyzing call $call_id",
            "Transcript: $transcript",
        )

        system, user = template.render(
            call_id="CALL001",
            transcript="Hello world",
        )

        assert "CALL001" in system
        assert "Hello world" in user

    def test_render_missing_var(self):
        """Unsupplied variables are left verbatim (safe_substitute)."""
        template = self._template("Call $call_id in $queue", "Text")

        # Fix: the user part was bound to an unused local; discard it.
        system, _ = template.render(call_id="CALL001")

        # safe_substitute leaves $queue as-is instead of raising KeyError.
        assert "$queue" in system

    def test_to_messages(self):
        """to_messages yields a [system, user] chat message list."""
        messages = self._template("System message", "User message").to_messages()

        assert len(messages) == 2
        assert messages[0]["role"] == "system"
        assert messages[1]["role"] == "user"
class TestPromptManager:
    """Tests for PromptManager loading, versioning, and caching."""

    def test_load_call_analysis_prompt(self, config_dir):
        """Loading a known prompt yields a fully populated template."""
        template = PromptManager(config_dir / "prompts").load(
            "call_analysis", "v1.0"
        )

        assert template.name == "call_analysis"
        assert template.version == "v1.0"
        assert len(template.system) > 0
        assert len(template.user) > 0

    def test_load_nonexistent_prompt(self, config_dir):
        """Unknown prompt names raise FileNotFoundError."""
        manager = PromptManager(config_dir / "prompts")

        with pytest.raises(FileNotFoundError):
            manager.load("nonexistent", "v1.0")

    def test_get_active_version(self, config_dir):
        """The active version is v2.0 (Blueprint-aligned prompt)."""
        manager = PromptManager(config_dir / "prompts")
        assert manager.get_active_version("call_analysis") == "v2.0"

    def test_list_prompt_types(self, config_dir):
        """call_analysis appears in the available prompt types."""
        available = PromptManager(config_dir / "prompts").list_prompt_types()
        assert "call_analysis" in available

    def test_caching(self, config_dir):
        """Repeated loads of the same prompt return the cached object."""
        manager = PromptManager(config_dir / "prompts")

        first = manager.load("call_analysis", "v1.0")
        second = manager.load("call_analysis", "v1.0")

        assert first is second  # identity, not just equality
class TestFormatFunctions:
    """Tests for the prompt formatting helper functions."""

    def test_format_events_empty(self):
        """An empty event list renders a 'no events' placeholder."""
        assert "No significant events" in format_events_for_prompt([])

    def test_format_events_with_events(self):
        """Event type, timestamp, and duration appear in the rendering."""
        sample_events = [
            Event(event_type=EventType.HOLD_START, start_time=10.0),
            Event(
                event_type=EventType.SILENCE,
                start_time=30.0,
                duration_sec=8.0,
            ),
        ]

        rendered = format_events_for_prompt(sample_events)

        assert "HOLD_START" in rendered
        assert "10.0s" in rendered
        assert "SILENCE" in rendered

    def test_format_transcript_basic(self):
        """Speaker labels are upper-cased and turn text is preserved."""
        rendered = format_transcript_for_prompt(
            [
                SpeakerTurn(
                    speaker="agent",
                    text="Hello",
                    start_time=0.0,
                    end_time=1.0,
                ),
                SpeakerTurn(
                    speaker="customer",
                    text="Hi there",
                    start_time=1.5,
                    end_time=3.0,
                ),
            ]
        )

        for fragment in ("AGENT", "Hello", "CUSTOMER", "Hi there"):
            assert fragment in rendered

    def test_format_transcript_truncation(self):
        """Overlong transcripts are cut and flagged as truncated."""
        long_turns = [
            SpeakerTurn(
                speaker="agent",
                text="A" * 5000,
                start_time=0.0,
                end_time=10.0,
            ),
            SpeakerTurn(
                speaker="customer",
                text="B" * 5000,
                start_time=10.0,
                end_time=20.0,
            ),
        ]

        rendered = format_transcript_for_prompt(long_turns, max_chars=6000)

        assert "truncated" in rendered
        assert len(rendered) < 8000
class TestAnalyzerValidation:
    """Tests for model-level RCA label validation rules."""

    def test_evidence_required(self):
        """RCA labels must carry at least one evidence span."""
        from src.models.call_analysis import EvidenceSpan, RCALabel

        span = EvidenceSpan(
            text="Es demasiado caro",
            start_time=10.0,
            end_time=12.0,
        )
        labeled = RCALabel(
            driver_code="PRICE_TOO_HIGH",
            confidence=0.9,
            evidence_spans=[span],
        )
        assert labeled.driver_code == "PRICE_TOO_HIGH"

        # An empty evidence list must be rejected.
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="PRICE_TOO_HIGH",
                confidence=0.9,
                evidence_spans=[],
            )

    def test_confidence_bounds(self):
        """confidence must lie within [0, 1]."""
        from src.models.call_analysis import EvidenceSpan, RCALabel

        evidence = [EvidenceSpan(text="test", start_time=0, end_time=1)]

        in_range = RCALabel(
            driver_code="TEST",
            confidence=0.5,
            evidence_spans=evidence,
        )
        assert in_range.confidence == 0.5

        # Values above 1 must be rejected.
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="TEST",
                confidence=1.5,
                evidence_spans=evidence,
            )

    def test_emergent_requires_proposed_label(self):
        """OTHER_EMERGENT labels must also name a proposed_label."""
        from src.models.call_analysis import EvidenceSpan, RCALabel

        evidence = [EvidenceSpan(text="test", start_time=0, end_time=1)]

        emergent = RCALabel(
            driver_code="OTHER_EMERGENT",
            confidence=0.7,
            evidence_spans=evidence,
            proposed_label="NEW_PATTERN",
        )
        assert emergent.proposed_label == "NEW_PATTERN"

        # Omitting proposed_label for an emergent driver must be rejected.
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="OTHER_EMERGENT",
                confidence=0.7,
                evidence_spans=evidence,
            )
@pytest.fixture
def config_dir(project_root):
    """Return the repository's config/ directory (prompt templates are loaded from config/prompts)."""
    return project_root / "config"
414
tests/unit/test_pipeline.py
Normal file
414
tests/unit/test_pipeline.py
Normal file
@@ -0,0 +1,414 @@
|
||||
"""
|
||||
CXInsights - Pipeline Tests
|
||||
|
||||
Tests for the end-to-end pipeline and exports.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from src.pipeline.models import (
|
||||
PipelineConfig,
|
||||
PipelineManifest,
|
||||
PipelineStage,
|
||||
StageManifest,
|
||||
StageStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestStageManifest:
    """Tests for StageManifest lifecycle and (de)serialization."""

    def test_create_stage_manifest(self):
        """A new manifest starts PENDING with zero items."""
        manifest = StageManifest(stage=PipelineStage.TRANSCRIPTION)

        assert manifest.stage == PipelineStage.TRANSCRIPTION
        assert manifest.status == StageStatus.PENDING
        assert manifest.total_items == 0

    def test_success_rate(self):
        """90 processed items with 10 failures yields a 0.8 success rate."""
        manifest = StageManifest(
            stage=PipelineStage.INFERENCE,
            total_items=100,
            processed_items=90,
            failed_items=10,
        )

        assert manifest.success_rate == 0.8

    def test_success_rate_zero_items(self):
        """No items means a 0.0 success rate, not a division error."""
        assert StageManifest(stage=PipelineStage.INFERENCE).success_rate == 0.0

    def test_duration(self):
        """Duration is the wall-clock span between start and completion."""
        manifest = StageManifest(
            stage=PipelineStage.INFERENCE,
            started_at=datetime(2024, 1, 1, 10, 0, 0),
            completed_at=datetime(2024, 1, 1, 10, 5, 30),
        )

        assert manifest.duration_sec == 330.0  # 5 min 30 sec

    def test_to_dict(self):
        """Enum fields serialize to their string values."""
        data = StageManifest(
            stage=PipelineStage.TRANSCRIPTION,
            status=StageStatus.COMPLETED,
            total_items=10,
            processed_items=10,
        ).to_dict()

        assert data["stage"] == "transcription"
        assert data["status"] == "completed"
        assert data["total_items"] == 10

    def test_from_dict(self):
        """Deserialization restores enums and counters from a plain dict."""
        payload = {
            "stage": "inference",
            "status": "running",
            "started_at": "2024-01-01T10:00:00",
            "completed_at": None,
            "total_items": 50,
            "processed_items": 25,
            "failed_items": 0,
            "skipped_items": 0,
            "errors": [],
            "output_dir": None,
            "metadata": {},
        }

        manifest = StageManifest.from_dict(payload)

        assert manifest.stage == PipelineStage.INFERENCE
        assert manifest.status == StageStatus.RUNNING
        assert manifest.total_items == 50
class TestPipelineManifest:
    """Unit tests for PipelineManifest lifecycle and persistence."""

    def test_create_manifest(self):
        """A new manifest is PENDING and pre-seeds one entry per stage."""
        fresh = PipelineManifest(batch_id="test_batch")

        assert fresh.batch_id == "test_batch"
        assert fresh.status == StageStatus.PENDING
        assert len(fresh.stages) == len(PipelineStage)

    def test_mark_stage_started(self):
        """Starting a stage flips it to RUNNING and records the start time."""
        tracker = PipelineManifest(batch_id="test")
        tracker.mark_stage_started(PipelineStage.TRANSCRIPTION, total_items=100)

        started = tracker.stages[PipelineStage.TRANSCRIPTION]
        assert started.status == StageStatus.RUNNING
        assert started.total_items == 100
        assert started.started_at is not None
        assert tracker.current_stage == PipelineStage.TRANSCRIPTION

    def test_mark_stage_completed(self):
        """Completion stores the counters and arbitrary metadata on the stage."""
        tracker = PipelineManifest(batch_id="test")
        tracker.mark_stage_started(PipelineStage.TRANSCRIPTION, 100)
        tracker.mark_stage_completed(
            PipelineStage.TRANSCRIPTION,
            processed=95,
            failed=5,
            metadata={"key": "value"},
        )

        done = tracker.stages[PipelineStage.TRANSCRIPTION]
        assert done.status == StageStatus.COMPLETED
        assert done.processed_items == 95
        assert done.failed_items == 5
        assert done.metadata["key"] == "value"

    def test_mark_stage_failed(self):
        """A stage failure records the error and fails the whole pipeline."""
        tracker = PipelineManifest(batch_id="test")
        tracker.mark_stage_started(PipelineStage.INFERENCE, 50)
        tracker.mark_stage_failed(PipelineStage.INFERENCE, "API error")

        broken = tracker.stages[PipelineStage.INFERENCE]
        assert broken.status == StageStatus.FAILED
        assert len(broken.errors) == 1
        assert "API error" in broken.errors[0]["error"]
        assert tracker.status == StageStatus.FAILED

    def test_can_resume_from(self):
        """Resume is allowed only when every earlier stage has completed."""
        tracker = PipelineManifest(batch_id="test")

        # First two stages done; compression is next in line.
        tracker.stages[PipelineStage.TRANSCRIPTION].status = StageStatus.COMPLETED
        tracker.stages[PipelineStage.FEATURE_EXTRACTION].status = StageStatus.COMPLETED

        assert tracker.can_resume_from(PipelineStage.COMPRESSION) is True
        # Inference is still blocked on the unfinished compression stage.
        assert tracker.can_resume_from(PipelineStage.INFERENCE) is False

    def test_get_resume_stage(self):
        """The resume point is the first stage not yet completed."""
        tracker = PipelineManifest(batch_id="test")

        # Everything pending: resume from the very first stage.
        assert tracker.get_resume_stage() == PipelineStage.TRANSCRIPTION

        tracker.stages[PipelineStage.TRANSCRIPTION].status = StageStatus.COMPLETED
        tracker.stages[PipelineStage.FEATURE_EXTRACTION].status = StageStatus.COMPLETED
        assert tracker.get_resume_stage() == PipelineStage.COMPRESSION

    def test_is_complete(self):
        """is_complete becomes true only once every stage has completed."""
        tracker = PipelineManifest(batch_id="test")

        assert tracker.is_complete is False

        for each_stage in PipelineStage:
            tracker.stages[each_stage].status = StageStatus.COMPLETED

        assert tracker.is_complete is True

    def test_save_and_load(self):
        """save() then load() round-trips ids, counters, and stage states."""
        original = PipelineManifest(
            batch_id="persist_test",
            total_audio_files=100,
        )
        original.mark_stage_started(PipelineStage.TRANSCRIPTION, 100)
        original.mark_stage_completed(PipelineStage.TRANSCRIPTION, 100)

        with tempfile.TemporaryDirectory() as workdir:
            target = Path(workdir) / "manifest.json"
            original.save(target)
            restored = PipelineManifest.load(target)

        assert restored.batch_id == "persist_test"
        assert restored.total_audio_files == 100
        assert restored.stages[PipelineStage.TRANSCRIPTION].status == StageStatus.COMPLETED
|
||||
|
||||
|
||||
class TestPipelineConfig:
    """Unit tests for the PipelineConfig model."""

    def test_default_config(self):
        """Defaults: mini model, compression on, json + excel exports."""
        defaults = PipelineConfig()

        assert defaults.inference_model == "gpt-4o-mini"
        assert defaults.use_compression is True
        assert "json" in defaults.export_formats
        assert "excel" in defaults.export_formats

    def test_custom_config(self):
        """Explicit keyword overrides take precedence over the defaults."""
        overridden = PipelineConfig(
            inference_model="gpt-4o",
            use_compression=False,
            export_formats=["json", "pdf"],
        )

        assert overridden.inference_model == "gpt-4o"
        assert overridden.use_compression is False
        assert "pdf" in overridden.export_formats

    def test_to_dict(self):
        """to_dict() exposes the model id and export formats as plain data."""
        snapshot = PipelineConfig().to_dict()

        assert "inference_model" in snapshot
        assert "export_formats" in snapshot
        assert isinstance(snapshot["export_formats"], list)
|
||||
|
||||
|
||||
class TestPipelineStages:
    """Tests for the PipelineStage enum."""

    def test_all_stages_defined(self):
        """Every expected stage name resolves to an enum member."""
        for expected_name in (
            "transcription",
            "feature_extraction",
            "compression",
            "inference",
            "aggregation",
            "export",
        ):
            assert PipelineStage(expected_name) is not None

    def test_stage_order(self):
        """Declaration order runs from transcription through export."""
        ordered = list(PipelineStage)

        assert ordered[0] == PipelineStage.TRANSCRIPTION
        assert ordered[-1] == PipelineStage.EXPORT
|
||||
|
||||
|
||||
class TestExports:
    """Tests for export functions.

    Exercises the JSON and PDF exporters against a small, fully
    hand-built BatchAggregation / CallAnalysis fixture so the tests
    run without any upstream pipeline stages.
    """

    @pytest.fixture
    def sample_aggregation(self):
        """Create sample aggregation for export tests.

        One lost-sales driver (PRICE_TOO_HIGH) and one poor-CX driver
        (LONG_HOLD), each with matching frequency and severity entries,
        plus an RCA tree summarizing a 100-call batch.
        """
        # Imported lazily so merely collecting this module does not
        # require the aggregation package.
        from src.aggregation.models import (
            BatchAggregation,
            DriverFrequency,
            DriverSeverity,
            ImpactLevel,
            RCATree,
        )

        return BatchAggregation(
            batch_id="export_test",
            total_calls_processed=100,
            successful_analyses=95,
            failed_analyses=5,
            lost_sales_frequencies=[
                DriverFrequency(
                    driver_code="PRICE_TOO_HIGH",
                    category="lost_sales",
                    total_occurrences=30,
                    calls_affected=25,
                    total_calls_in_batch=100,
                    # Rates are pre-computed to stay consistent with the
                    # counts above (30/100 and 25/100).
                    occurrence_rate=0.30,
                    call_rate=0.25,
                    avg_confidence=0.85,
                    min_confidence=0.7,
                    max_confidence=0.95,
                ),
            ],
            poor_cx_frequencies=[
                DriverFrequency(
                    driver_code="LONG_HOLD",
                    category="poor_cx",
                    total_occurrences=20,
                    calls_affected=20,
                    total_calls_in_batch=100,
                    occurrence_rate=0.20,
                    call_rate=0.20,
                    avg_confidence=0.9,
                    min_confidence=0.8,
                    max_confidence=0.95,
                ),
            ],
            lost_sales_severities=[
                DriverSeverity(
                    driver_code="PRICE_TOO_HIGH",
                    category="lost_sales",
                    base_severity=0.8,
                    frequency_factor=0.5,
                    confidence_factor=0.85,
                    co_occurrence_factor=0.2,
                    severity_score=65.0,
                    impact_level=ImpactLevel.HIGH,
                ),
            ],
            poor_cx_severities=[
                DriverSeverity(
                    driver_code="LONG_HOLD",
                    category="poor_cx",
                    base_severity=0.7,
                    frequency_factor=0.4,
                    confidence_factor=0.9,
                    co_occurrence_factor=0.1,
                    severity_score=55.0,
                    impact_level=ImpactLevel.HIGH,
                ),
            ],
            rca_tree=RCATree(
                batch_id="export_test",
                total_calls=100,
                calls_with_lost_sales=25,
                calls_with_poor_cx=20,
                calls_with_both=5,
                top_lost_sales_drivers=["PRICE_TOO_HIGH"],
                top_poor_cx_drivers=["LONG_HOLD"],
            ),
        )

    @pytest.fixture
    def sample_analyses(self):
        """Create sample analyses for export tests.

        A single successful SALE_LOST analysis with empty driver lists —
        the minimum per-call payload the exporters accept.
        """
        from src.models.call_analysis import (
            CallAnalysis,
            CallOutcome,
            ObservedFeatures,
            ProcessingStatus,
            Traceability,
        )

        return [
            CallAnalysis(
                call_id="CALL001",
                batch_id="export_test",
                status=ProcessingStatus.SUCCESS,
                observed=ObservedFeatures(audio_duration_sec=60),
                outcome=CallOutcome.SALE_LOST,
                lost_sales_drivers=[],
                poor_cx_drivers=[],
                traceability=Traceability(
                    schema_version="1.0",
                    prompt_version="v1.0",
                    model_id="gpt-4o-mini",
                ),
            ),
        ]

    def test_json_export(self, sample_aggregation, sample_analyses):
        """JSON export writes summary.json with the expected top-level keys."""
        from src.exports.json_export import export_to_json

        with tempfile.TemporaryDirectory() as tmp:
            output_dir = Path(tmp)
            result = export_to_json(
                "test_batch",
                sample_aggregation,
                sample_analyses,
                output_dir,
            )

            assert result.exists()
            assert result.name == "summary.json"

            # Verify content: re-read the file and check the structure
            # before the temporary directory is cleaned up.
            with open(result) as f:
                data = json.load(f)

            assert data["batch_id"] == "test_batch"
            assert "summary" in data
            assert "lost_sales" in data
            assert "poor_cx" in data

    def test_pdf_export_html_fallback(self, sample_aggregation):
        """PDF export produces a file, falling back to HTML when the
        optional PDF renderer is unavailable."""
        from src.exports.pdf_export import export_to_pdf

        with tempfile.TemporaryDirectory() as tmp:
            output_dir = Path(tmp)
            result = export_to_pdf("test_batch", sample_aggregation, output_dir)

            assert result.exists()
            # Should be HTML if weasyprint not installed — either suffix
            # counts as success for this test.
            assert result.suffix in [".pdf", ".html"]
|
||||
322
tests/unit/test_transcription.py
Normal file
322
tests/unit/test_transcription.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""
|
||||
CXInsights - Transcription Module Tests
|
||||
|
||||
Unit tests for transcription models and utilities.
|
||||
Does NOT test actual API calls (those are in integration tests).
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from src.transcription.models import (
|
||||
AudioMetadata,
|
||||
SpeakerTurn,
|
||||
Transcript,
|
||||
TranscriptMetadata,
|
||||
TranscriptionConfig,
|
||||
TranscriptionError,
|
||||
TranscriptionResult,
|
||||
TranscriptionStatus,
|
||||
)
|
||||
|
||||
|
||||
class TestSpeakerTurn:
    """Tests for the SpeakerTurn model."""

    def test_create_valid_turn(self):
        """All explicitly supplied fields are stored verbatim."""
        turn = SpeakerTurn(
            speaker="A",
            text="Hola, buenos días",
            start_time=0.0,
            end_time=2.5,
            confidence=0.95,
        )

        assert turn.speaker == "A"
        assert turn.text == "Hola, buenos días"
        assert turn.start_time == 0.0
        assert turn.end_time == 2.5
        assert turn.confidence == 0.95

    def test_duration_computed(self):
        """duration_sec is derived as end_time - start_time."""
        turn = SpeakerTurn(
            speaker="A",
            text="Test",
            start_time=10.0,
            end_time=15.5,
        )

        assert turn.duration_sec == 5.5

    def test_word_count_computed(self):
        """word_count matches a whitespace split of the text.

        BUG FIX: the original assertion expected 7, but the sample
        sentence has 8 whitespace-separated words
        (Esto/es/una/prueba/de/conteo/de/palabras). This matches the
        split semantics pinned by test_empty_text_word_count below and
        by the 6+4+6 totals asserted in TestTranscript.test_total_words.
        """
        turn = SpeakerTurn(
            speaker="A",
            text="Esto es una prueba de conteo de palabras",
            start_time=0.0,
            end_time=5.0,
        )

        assert turn.word_count == 8

    def test_empty_text_word_count(self):
        """Empty text still reports one 'word'."""
        turn = SpeakerTurn(
            speaker="A",
            text="",
            start_time=0.0,
            end_time=1.0,
        )

        # Documented model quirk: "".split(" ") yields [''], hence 1.
        assert turn.word_count == 1

    def test_confidence_optional(self):
        """confidence defaults to None when not supplied."""
        turn = SpeakerTurn(
            speaker="A",
            text="Test",
            start_time=0.0,
            end_time=1.0,
        )

        assert turn.confidence is None
|
||||
|
||||
|
||||
class TestTranscriptMetadata:
    """Tests for the TranscriptMetadata model."""

    def test_create_metadata(self):
        """All explicitly supplied fields are stored verbatim."""
        meta = TranscriptMetadata(
            audio_duration_sec=420.5,
            audio_file="call_001.mp3",
            language="es",
            provider="assemblyai",
            job_id="abc123",
        )

        assert meta.audio_duration_sec == 420.5
        assert meta.audio_file == "call_001.mp3"
        assert meta.language == "es"
        assert meta.provider == "assemblyai"
        assert meta.job_id == "abc123"

    def test_created_at_default(self):
        """created_at is auto-populated with a datetime when omitted."""
        meta = TranscriptMetadata(
            audio_duration_sec=100.0,
            audio_file="test.mp3",
            provider="assemblyai",
        )

        assert meta.created_at is not None
        assert isinstance(meta.created_at, datetime)
|
||||
|
||||
|
||||
class TestTranscript:
    """Tests for Transcript model.

    Uses a fixed three-turn agent/customer conversation so the word,
    turn, and speaker expectations below can be verified by hand.
    """

    @pytest.fixture
    def sample_transcript(self):
        """Create a sample transcript for testing.

        Three turns: agent greeting, customer cancellation request,
        agent follow-up — two distinct speakers in total.
        """
        return Transcript(
            call_id="CALL001",
            turns=[
                SpeakerTurn(
                    speaker="agent",
                    text="Buenos días, ¿en qué puedo ayudarle?",
                    start_time=0.0,
                    end_time=3.0,
                ),
                SpeakerTurn(
                    speaker="customer",
                    text="Quiero cancelar mi servicio",
                    start_time=3.5,
                    end_time=6.0,
                ),
                SpeakerTurn(
                    speaker="agent",
                    text="Entiendo, ¿me puede indicar el motivo?",
                    start_time=6.5,
                    end_time=9.0,
                ),
            ],
            metadata=TranscriptMetadata(
                audio_duration_sec=420.0,
                audio_file="CALL001.mp3",
                provider="assemblyai",
                speaker_count=2,
            ),
        )

    def test_total_turns(self, sample_transcript):
        """Test total turns count."""
        assert sample_transcript.total_turns == 3

    def test_total_words(self, sample_transcript):
        """Test total words count (whitespace-separated, hand-counted)."""
        # "Buenos días, ¿en qué puedo ayudarle?" = 6 words
        # "Quiero cancelar mi servicio" = 4 words
        # "Entiendo, ¿me puede indicar el motivo?" = 6 words
        assert sample_transcript.total_words == 16

    def test_get_full_text(self, sample_transcript):
        """Test getting full text — contains text from every turn."""
        full_text = sample_transcript.get_full_text()
        assert "Buenos días" in full_text
        assert "cancelar mi servicio" in full_text

    def test_get_speaker_text(self, sample_transcript):
        """Test getting text for a specific speaker — no cross-speaker leakage."""
        agent_text = sample_transcript.get_speaker_text("agent")
        customer_text = sample_transcript.get_speaker_text("customer")

        assert "Buenos días" in agent_text
        # The customer's words must not appear in the agent's text.
        assert "cancelar" not in agent_text
        assert "cancelar mi servicio" in customer_text

    def test_get_speakers(self, sample_transcript):
        """Test getting unique speakers (deduplicated across turns)."""
        speakers = sample_transcript.get_speakers()

        assert len(speakers) == 2
        assert "agent" in speakers
        assert "customer" in speakers
|
||||
|
||||
|
||||
class TestTranscriptionResult:
    """Tests for the TranscriptionResult model."""

    def test_success_result(self):
        """The success() constructor yields COMPLETED with a transcript and no error."""
        empty_transcript = Transcript(
            call_id="CALL001",
            turns=[],
            metadata=TranscriptMetadata(
                audio_duration_sec=100.0,
                audio_file="test.mp3",
                provider="assemblyai",
            ),
        )

        outcome = TranscriptionResult.success(
            call_id="CALL001",
            audio_path=Path("test.mp3"),
            transcript=empty_transcript,
        )

        assert outcome.status == TranscriptionStatus.COMPLETED
        assert outcome.is_success is True
        assert outcome.transcript is not None
        assert outcome.error is None

    def test_failure_result(self):
        """The failure() constructor records the error and carries no transcript."""
        outcome = TranscriptionResult.failure(
            call_id="CALL001",
            audio_path=Path("test.mp3"),
            error=TranscriptionError.API_ERROR,
            error_message="Rate limit exceeded",
        )

        assert outcome.status == TranscriptionStatus.FAILED
        assert outcome.is_success is False
        assert outcome.transcript is None
        assert outcome.error == TranscriptionError.API_ERROR
        assert outcome.error_message == "Rate limit exceeded"

    def test_processing_time_computed(self):
        """processing_time_sec spans started_at to completed_at."""
        outcome = TranscriptionResult(
            call_id="CALL001",
            audio_path="test.mp3",
            status=TranscriptionStatus.COMPLETED,
            started_at=datetime(2024, 1, 1, 12, 0, 0),
            completed_at=datetime(2024, 1, 1, 12, 0, 30),
        )

        assert outcome.processing_time_sec == 30.0
|
||||
|
||||
|
||||
class TestAudioMetadata:
    """Tests for the AudioMetadata model."""

    def test_create_metadata(self):
        """All supplied audio properties are stored as given."""
        info = AudioMetadata(
            file_path="/data/audio/call.mp3",
            file_size_bytes=5242880,  # 5 MB
            duration_sec=420.0,  # 7 minutes
            format="mp3",
            codec="mp3",
            sample_rate=44100,
            channels=2,
            bit_rate=128000,
        )

        assert info.file_path == "/data/audio/call.mp3"
        assert info.duration_sec == 420.0
        assert info.format == "mp3"

    def test_duration_minutes(self):
        """duration_minutes converts seconds to minutes (420 s -> 7.0)."""
        info = AudioMetadata(
            file_path="test.mp3",
            file_size_bytes=1000000,
            duration_sec=420.0,
            format="mp3",
        )

        assert info.duration_minutes == 7.0

    def test_file_size_mb(self):
        """file_size_mb converts bytes to megabytes (5242880 B -> 5.0)."""
        info = AudioMetadata(
            file_path="test.mp3",
            file_size_bytes=5242880,  # 5 MB
            duration_sec=100.0,
            format="mp3",
        )

        assert info.file_size_mb == 5.0
|
||||
|
||||
|
||||
class TestTranscriptionConfig:
    """Tests for the TranscriptionConfig model."""

    def test_default_config(self):
        """Defaults target Spanish with diarization and text formatting on."""
        defaults = TranscriptionConfig()

        assert defaults.language_code == "es"
        assert defaults.speaker_labels is True
        assert defaults.punctuate is True
        assert defaults.format_text is True
        assert defaults.auto_chapters is False

    def test_custom_config(self):
        """Keyword overrides replace the corresponding defaults."""
        tweaked = TranscriptionConfig(
            language_code="en",
            speaker_labels=False,
            auto_chapters=True,
        )

        assert tweaked.language_code == "en"
        assert tweaked.speaker_labels is False
        assert tweaked.auto_chapters is True
|
||||
|
||||
|
||||
class TestTranscriptionError:
    """Tests for the TranscriptionError enum."""

    def test_error_values(self):
        """Each member compares equal to its plain-string value."""
        for member, literal in (
            (TranscriptionError.FILE_NOT_FOUND, "FILE_NOT_FOUND"),
            (TranscriptionError.API_ERROR, "API_ERROR"),
            (TranscriptionError.RATE_LIMITED, "RATE_LIMITED"),
            (TranscriptionError.TIMEOUT, "TIMEOUT"),
        ):
            assert member == literal
||||
Reference in New Issue
Block a user