feat: Add Streamlit dashboard with Blueprint compliance (v2.1.0)
Dashboard Features: - 8 navigation sections: Overview, Outcomes, Poor CX, FCR, Churn, Agent, Call Explorer, Export - Beyond Brand Identity styling (colors #6D84E3, Outfit font) - RCA Sankey diagram (Driver → Outcome → Churn Risk flow) - Correlation heatmaps (driver co-occurrence, driver-outcome) - Outcome Deep Dive (root causes, correlation, duration analysis) - Export functionality (Excel, HTML, JSON) Blueprint Compliance: - FCR: 4 categories (Primera Llamada/Rellamada × Sin/Con Riesgo de Fuga) - Churn: Binary view (Sin Riesgo de Fuga / En Riesgo de Fuga) - Agent: Talento Para Replicar / Oportunidades de Mejora - Fixed FCR rate calculation (only FIRST_CALL counts as success) Technical: - Streamlit + Plotly for interactive visualizations - Light theme configuration (.streamlit/config.toml) - Fixed Plotly colorbar titlefont deprecation Documentation: - Updated PROJECT_CONTEXT.md, TODO.md, CHANGELOG.md - Added 4 new technical decisions (TD-014 to TD-017) - Created TROUBLESHOOTING.md with 10 common issues Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
393
tests/unit/test_inference.py
Normal file
393
tests/unit/test_inference.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
CXInsights - Inference Module Tests
|
||||
|
||||
Tests for LLM client, prompt manager, and analyzer.
|
||||
Uses mocks for LLM calls to avoid API costs.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.inference.client import LLMClient, LLMClientConfig, LLMResponse
|
||||
from src.inference.prompt_manager import (
|
||||
PromptManager,
|
||||
PromptTemplate,
|
||||
format_events_for_prompt,
|
||||
format_transcript_for_prompt,
|
||||
)
|
||||
from src.models.call_analysis import Event, EventType
|
||||
from src.transcription.models import SpeakerTurn, Transcript, TranscriptMetadata
|
||||
|
||||
|
||||
class TestLLMResponse:
    """Unit tests for the LLMResponse data container."""

    def test_cost_estimate(self):
        """Cost estimate should follow GPT-4o-mini per-token pricing."""
        resp = LLMResponse(
            content="test",
            prompt_tokens=1000,
            completion_tokens=500,
            total_tokens=1500,
        )

        # Pricing: $0.15 per 1M input tokens, $0.60 per 1M output tokens.
        input_cost = (1000 / 1_000_000) * 0.15
        output_cost = (500 / 1_000_000) * 0.60
        assert abs(resp.cost_estimate_usd - (input_cost + output_cost)) < 0.0001

    def test_success_flag(self):
        """The success flag should reflect what the constructor received."""
        ok = LLMResponse(content="test", success=True)
        bad = LLMResponse(content="", success=False, error="API error")

        assert ok.success is True
        assert bad.success is False

    def test_parsed_json(self):
        """A pre-parsed JSON payload should be stored unchanged."""
        resp = LLMResponse(
            content='{"key": "value"}',
            parsed_json={"key": "value"},
        )

        assert resp.parsed_json == {"key": "value"}
|
||||
|
||||
|
||||
class TestLLMClientConfig:
    """Unit tests for LLMClientConfig defaults and overrides."""

    def test_default_config(self):
        """A bare config should carry the documented defaults."""
        cfg = LLMClientConfig()

        assert cfg.model == "gpt-4o-mini"
        assert cfg.temperature == 0.1
        assert cfg.max_tokens == 4000
        assert cfg.json_mode is True

    def test_custom_config(self):
        """Explicit keyword overrides should win over the defaults."""
        cfg = LLMClientConfig(model="gpt-4o", temperature=0.5, max_tokens=8000)

        assert cfg.model == "gpt-4o"
        assert cfg.temperature == 0.5
        assert cfg.max_tokens == 8000
|
||||
|
||||
|
||||
class TestLLMClient:
    """Unit tests for LLMClient construction, JSON parsing, and usage stats.

    All tests patch the environment so no real API key (or API call) is needed.
    """

    def test_requires_api_key(self):
        """Constructing with no API key anywhere should raise ValueError."""
        with patch.dict("os.environ", {}, clear=True):
            with pytest.raises(ValueError, match="API key required"):
                LLMClient()

    def test_parse_json_valid(self):
        """Plain JSON text parses to the expected dict."""
        with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
            parsed = LLMClient()._parse_json('{"key": "value"}')
            assert parsed == {"key": "value"}

    def test_parse_json_with_markdown(self):
        """JSON wrapped in a markdown code fence still parses."""
        with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
            fenced = '```json\n{"key": "value"}\n```'
            assert LLMClient()._parse_json(fenced) == {"key": "value"}

    def test_parse_json_extract_from_text(self):
        """A JSON object embedded in surrounding prose is extracted."""
        with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
            prose = 'Here is the result: {"key": "value"} end.'
            assert LLMClient()._parse_json(prose) == {"key": "value"}

    def test_parse_json_invalid(self):
        """Unparseable content yields None rather than raising."""
        with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
            assert LLMClient()._parse_json("not json at all") is None

    def test_usage_stats_tracking(self):
        """A fresh client reports zero calls and zero tokens."""
        with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
            stats = LLMClient().get_usage_stats()

            assert stats["total_calls"] == 0
            assert stats["total_tokens"] == 0

    def test_reset_usage_stats(self):
        """reset_usage_stats should zero out accumulated counters."""
        with patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}):
            llm = LLMClient()
            llm._total_calls = 10
            llm._total_tokens = 5000

            llm.reset_usage_stats()

            stats = llm.get_usage_stats()
            assert stats["total_calls"] == 0
            assert stats["total_tokens"] == 0
|
||||
|
||||
|
||||
class TestPromptTemplate:
    """Unit tests for PromptTemplate rendering and message building."""

    def test_render_basic(self):
        """$-placeholders should be substituted in both prompt parts."""
        tmpl = PromptTemplate(
            name="test",
            version="v1.0",
            system="You are analyzing call $call_id",
            user="Transcript: $transcript",
        )

        sys_text, user_text = tmpl.render(call_id="CALL001", transcript="Hello world")

        assert "CALL001" in sys_text
        assert "Hello world" in user_text

    def test_render_missing_var(self):
        """Unset placeholders survive rendering (safe_substitute semantics)."""
        tmpl = PromptTemplate(
            name="test",
            version="v1.0",
            system="Call $call_id in $queue",
            user="Text",
        )

        sys_text, _ = tmpl.render(call_id="CALL001")

        # $queue was never supplied, so it must remain verbatim.
        assert "$queue" in sys_text

    def test_to_messages(self):
        """to_messages yields a system message followed by a user message."""
        tmpl = PromptTemplate(
            name="test",
            version="v1.0",
            system="System message",
            user="User message",
        )

        messages = tmpl.to_messages()

        assert len(messages) == 2
        assert messages[0]["role"] == "system"
        assert messages[1]["role"] == "user"
|
||||
|
||||
|
||||
class TestPromptManager:
    """Unit tests for PromptManager loading, versioning, and caching.

    Relies on the `config_dir` fixture pointing at the real config/ tree.
    """

    def test_load_call_analysis_prompt(self, config_dir):
        """Loading a known prompt returns a fully populated template."""
        mgr = PromptManager(config_dir / "prompts")

        tmpl = mgr.load("call_analysis", "v1.0")

        assert tmpl.name == "call_analysis"
        assert tmpl.version == "v1.0"
        assert len(tmpl.system) > 0
        assert len(tmpl.user) > 0

    def test_load_nonexistent_prompt(self, config_dir):
        """Unknown prompt names should raise FileNotFoundError."""
        mgr = PromptManager(config_dir / "prompts")

        with pytest.raises(FileNotFoundError):
            mgr.load("nonexistent", "v1.0")

    def test_get_active_version(self, config_dir):
        """The active call_analysis version is pinned to v2.0."""
        mgr = PromptManager(config_dir / "prompts")

        # v2.0 is the Blueprint-aligned revision of the prompt.
        assert mgr.get_active_version("call_analysis") == "v2.0"

    def test_list_prompt_types(self, config_dir):
        """call_analysis must be among the discoverable prompt types."""
        mgr = PromptManager(config_dir / "prompts")

        assert "call_analysis" in mgr.list_prompt_types()

    def test_caching(self, config_dir):
        """Repeated loads of the same prompt return the cached object."""
        mgr = PromptManager(config_dir / "prompts")

        first = mgr.load("call_analysis", "v1.0")
        second = mgr.load("call_analysis", "v1.0")

        # Identity, not mere equality: the cache hands back the same instance.
        assert first is second
|
||||
|
||||
|
||||
class TestFormatFunctions:
    """Unit tests for the prompt-formatting helper functions."""

    def test_format_events_empty(self):
        """An empty event list produces the 'no events' placeholder text."""
        assert "No significant events" in format_events_for_prompt([])

    def test_format_events_with_events(self):
        """Event type names and start times appear in the formatted output."""
        hold = Event(
            event_type=EventType.HOLD_START,
            start_time=10.0,
        )
        silence = Event(
            event_type=EventType.SILENCE,
            start_time=30.0,
            duration_sec=8.0,
        )

        rendered = format_events_for_prompt([hold, silence])

        assert "HOLD_START" in rendered
        assert "10.0s" in rendered
        assert "SILENCE" in rendered

    def test_format_transcript_basic(self):
        """Speaker labels are uppercased and turn text is preserved."""
        turns = [
            SpeakerTurn(
                speaker="agent",
                text="Hello",
                start_time=0.0,
                end_time=1.0,
            ),
            SpeakerTurn(
                speaker="customer",
                text="Hi there",
                start_time=1.5,
                end_time=3.0,
            ),
        ]

        rendered = format_transcript_for_prompt(turns)

        for expected in ("AGENT", "Hello", "CUSTOMER", "Hi there"):
            assert expected in rendered

    def test_format_transcript_truncation(self):
        """Transcripts beyond max_chars are cut and flagged as truncated."""
        long_turns = [
            SpeakerTurn(
                speaker="agent",
                text="A" * 5000,  # far over the budget on its own
                start_time=0.0,
                end_time=10.0,
            ),
            SpeakerTurn(
                speaker="customer",
                text="B" * 5000,
                start_time=10.0,
                end_time=20.0,
            ),
        ]

        rendered = format_transcript_for_prompt(long_turns, max_chars=6000)

        assert "truncated" in rendered
        assert len(rendered) < 8000
|
||||
|
||||
|
||||
class TestAnalyzerValidation:
    """Unit tests for RCALabel validation rules.

    Imports are kept local to each test because EvidenceSpan/RCALabel are not
    part of this module's top-level imports.
    """

    def test_evidence_required(self):
        """Every RCA label must cite at least one evidence span."""
        from src.models.call_analysis import EvidenceSpan, RCALabel

        span = EvidenceSpan(
            text="Es demasiado caro",
            start_time=10.0,
            end_time=12.0,
        )
        labelled = RCALabel(
            driver_code="PRICE_TOO_HIGH",
            confidence=0.9,
            evidence_spans=[span],
        )
        assert labelled.driver_code == "PRICE_TOO_HIGH"

        # An empty evidence list must be rejected.
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="PRICE_TOO_HIGH",
                confidence=0.9,
                evidence_spans=[],
            )

    def test_confidence_bounds(self):
        """Confidence values outside [0, 1] must be rejected."""
        from src.models.call_analysis import EvidenceSpan, RCALabel

        spans = [EvidenceSpan(text="test", start_time=0, end_time=1)]

        in_range = RCALabel(
            driver_code="TEST",
            confidence=0.5,
            evidence_spans=spans,
        )
        assert in_range.confidence == 0.5

        # Above the upper bound.
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="TEST",
                confidence=1.5,
                evidence_spans=spans,
            )

    def test_emergent_requires_proposed_label(self):
        """OTHER_EMERGENT labels must carry a proposed_label suggestion."""
        from src.models.call_analysis import EvidenceSpan, RCALabel

        spans = [EvidenceSpan(text="test", start_time=0, end_time=1)]

        emergent = RCALabel(
            driver_code="OTHER_EMERGENT",
            confidence=0.7,
            evidence_spans=spans,
            proposed_label="NEW_PATTERN",
        )
        assert emergent.proposed_label == "NEW_PATTERN"

        # Emergent driver without a proposed label is invalid.
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="OTHER_EMERGENT",
                confidence=0.7,
                evidence_spans=spans,
            )
|
||||
|
||||
|
||||
@pytest.fixture
def config_dir(project_root):
    """Return the config directory.

    NOTE(review): depends on a `project_root` fixture defined elsewhere
    (presumably conftest.py) — not visible in this file.
    """
    return project_root / "config"
|
||||
Reference in New Issue
Block a user