Files
BeyondCX_Insights/tests/unit/test_inference.py
sujucu70 75e7b9da3d feat: Add Streamlit dashboard with Blueprint compliance (v2.1.0)
Dashboard Features:
- 8 navigation sections: Overview, Outcomes, Poor CX, FCR, Churn, Agent, Call Explorer, Export
- Beyond Brand Identity styling (colors #6D84E3, Outfit font)
- RCA Sankey diagram (Driver → Outcome → Churn Risk flow)
- Correlation heatmaps (driver co-occurrence, driver-outcome)
- Outcome Deep Dive (root causes, correlation, duration analysis)
- Export functionality (Excel, HTML, JSON)

Blueprint Compliance:
- FCR: 4 categories (Primera Llamada/Rellamada × Sin/Con Riesgo de Fuga)
- Churn: Binary view (Sin Riesgo de Fuga / En Riesgo de Fuga)
- Agent: Talento Para Replicar / Oportunidades de Mejora
- Fixed FCR rate calculation (only FIRST_CALL counts as success)

Technical:
- Streamlit + Plotly for interactive visualizations
- Light theme configuration (.streamlit/config.toml)
- Fixed Plotly colorbar titlefont deprecation

Documentation:
- Updated PROJECT_CONTEXT.md, TODO.md, CHANGELOG.md
- Added 4 new technical decisions (TD-014 to TD-017)
- Created TROUBLESHOOTING.md with 10 common issues

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 16:27:30 +01:00

394 lines
12 KiB
Python

"""
CXInsights - Inference Module Tests
Tests for LLM client, prompt manager, and analyzer.
Uses mocks for LLM calls to avoid API costs.
"""
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.inference.client import LLMClient, LLMClientConfig, LLMResponse
from src.inference.prompt_manager import (
PromptManager,
PromptTemplate,
format_events_for_prompt,
format_transcript_for_prompt,
)
from src.models.call_analysis import Event, EventType
from src.transcription.models import SpeakerTurn, Transcript, TranscriptMetadata
class TestLLMResponse:
    """Unit tests for the LLMResponse data container."""

    def test_cost_estimate(self):
        """Cost estimate follows GPT-4o-mini per-million-token pricing."""
        resp = LLMResponse(
            content="test",
            prompt_tokens=1000,
            completion_tokens=500,
            total_tokens=1500,
        )
        # Pricing: $0.15 per 1M input tokens, $0.60 per 1M output tokens.
        want = 1000 * 0.15 / 1_000_000 + 500 * 0.60 / 1_000_000
        assert abs(resp.cost_estimate_usd - want) < 0.0001

    def test_success_flag(self):
        """The success flag round-trips for both outcomes."""
        ok = LLMResponse(content="test", success=True)
        bad = LLMResponse(content="", success=False, error="API error")
        assert ok.success is True
        assert bad.success is False

    def test_parsed_json(self):
        """A pre-parsed JSON payload is stored alongside the raw content."""
        payload = {"key": "value"}
        resp = LLMResponse(content='{"key": "value"}', parsed_json=payload)
        assert resp.parsed_json == payload
class TestLLMClientConfig:
    """Unit tests for LLMClientConfig defaults and overrides."""

    def test_default_config(self):
        """A bare config carries the documented defaults."""
        cfg = LLMClientConfig()
        assert cfg.model == "gpt-4o-mini"
        assert cfg.temperature == 0.1
        assert cfg.max_tokens == 4000
        assert cfg.json_mode is True

    def test_custom_config(self):
        """Explicit keyword arguments override every default they touch."""
        cfg = LLMClientConfig(model="gpt-4o", temperature=0.5, max_tokens=8000)
        assert (cfg.model, cfg.temperature, cfg.max_tokens) == (
            "gpt-4o",
            0.5,
            8000,
        )
class TestLLMClient:
    """Unit tests for LLMClient construction, JSON parsing, and usage stats."""

    # Fake environment giving the client a dummy credential; no API call is made.
    _ENV = {"OPENAI_API_KEY": "test-key"}

    def test_requires_api_key(self):
        """Constructing a client with no key anywhere must fail loudly."""
        with patch.dict("os.environ", {}, clear=True):
            with pytest.raises(ValueError, match="API key required"):
                LLMClient()

    def test_parse_json_valid(self):
        """A plain JSON document parses to the matching dict."""
        with patch.dict("os.environ", self._ENV):
            parsed = LLMClient()._parse_json('{"key": "value"}')
        assert parsed == {"key": "value"}

    def test_parse_json_with_markdown(self):
        """JSON wrapped in a markdown fence is unwrapped before parsing."""
        fenced = '```json\n{"key": "value"}\n```'
        with patch.dict("os.environ", self._ENV):
            parsed = LLMClient()._parse_json(fenced)
        assert parsed == {"key": "value"}

    def test_parse_json_extract_from_text(self):
        """A JSON object embedded in prose is located and parsed."""
        noisy = 'Here is the result: {"key": "value"} end.'
        with patch.dict("os.environ", self._ENV):
            parsed = LLMClient()._parse_json(noisy)
        assert parsed == {"key": "value"}

    def test_parse_json_invalid(self):
        """Content with no JSON at all yields None rather than raising."""
        with patch.dict("os.environ", self._ENV):
            parsed = LLMClient()._parse_json("not json at all")
        assert parsed is None

    def test_usage_stats_tracking(self):
        """A fresh client reports zero calls and zero tokens."""
        with patch.dict("os.environ", self._ENV):
            stats = LLMClient().get_usage_stats()
        assert stats["total_calls"] == 0
        assert stats["total_tokens"] == 0

    def test_reset_usage_stats(self):
        """reset_usage_stats clears previously accumulated counters."""
        with patch.dict("os.environ", self._ENV):
            client = LLMClient()
            # Simulate prior usage directly on the private counters.
            client._total_calls = 10
            client._total_tokens = 5000
            client.reset_usage_stats()
            stats = client.get_usage_stats()
        assert stats["total_calls"] == 0
        assert stats["total_tokens"] == 0
class TestPromptTemplate:
    """Unit tests for PromptTemplate rendering and message conversion."""

    @staticmethod
    def _template(system, user):
        """Build a throwaway template with fixed name/version."""
        return PromptTemplate(name="test", version="v1.0", system=system, user=user)

    def test_render_basic(self):
        """Variables are substituted into both system and user parts."""
        tpl = self._template("You are analyzing call $call_id", "Transcript: $transcript")
        system, user = tpl.render(call_id="CALL001", transcript="Hello world")
        assert "CALL001" in system
        assert "Hello world" in user

    def test_render_missing_var(self):
        """Unsupplied variables survive rendering (safe_substitute semantics)."""
        tpl = self._template("Call $call_id in $queue", "Text")
        system, _ = tpl.render(call_id="CALL001")
        # safe_substitute leaves unknown placeholders untouched.
        assert "$queue" in system

    def test_to_messages(self):
        """to_messages yields a system-then-user chat message pair."""
        tpl = self._template("System message", "User message")
        messages = tpl.to_messages()
        assert [m["role"] for m in messages] == ["system", "user"]
class TestPromptManager:
    """Unit tests for PromptManager loading, versioning, and caching."""

    def test_load_call_analysis_prompt(self, config_dir):
        """Loading a known prompt yields a populated template."""
        tpl = PromptManager(config_dir / "prompts").load("call_analysis", "v1.0")
        assert tpl.name == "call_analysis"
        assert tpl.version == "v1.0"
        assert tpl.system  # non-empty system prompt
        assert tpl.user  # non-empty user prompt

    def test_load_nonexistent_prompt(self, config_dir):
        """Unknown prompt names raise FileNotFoundError."""
        manager = PromptManager(config_dir / "prompts")
        with pytest.raises(FileNotFoundError):
            manager.load("nonexistent", "v1.0")

    def test_get_active_version(self, config_dir):
        """The active call_analysis version is the Blueprint-aligned v2.0."""
        manager = PromptManager(config_dir / "prompts")
        assert manager.get_active_version("call_analysis") == "v2.0"

    def test_list_prompt_types(self, config_dir):
        """call_analysis appears among the discoverable prompt types."""
        manager = PromptManager(config_dir / "prompts")
        assert "call_analysis" in manager.list_prompt_types()

    def test_caching(self, config_dir):
        """Repeated loads of the same prompt/version hit the cache."""
        manager = PromptManager(config_dir / "prompts")
        first = manager.load("call_analysis", "v1.0")
        second = manager.load("call_analysis", "v1.0")
        # Identity (not just equality) proves the cached object is reused.
        assert first is second
class TestFormatFunctions:
    """Unit tests for the prompt-formatting helper functions."""

    def test_format_events_empty(self):
        """An empty event list renders a 'no events' placeholder."""
        assert "No significant events" in format_events_for_prompt([])

    def test_format_events_with_events(self):
        """Event types and timestamps are surfaced in the rendered text."""
        rendered = format_events_for_prompt(
            [
                Event(event_type=EventType.HOLD_START, start_time=10.0),
                Event(event_type=EventType.SILENCE, start_time=30.0, duration_sec=8.0),
            ]
        )
        for fragment in ("HOLD_START", "10.0s", "SILENCE"):
            assert fragment in rendered

    def test_format_transcript_basic(self):
        """Speaker labels are upper-cased and turn text is preserved."""
        turns = [
            SpeakerTurn(speaker="agent", text="Hello", start_time=0.0, end_time=1.0),
            SpeakerTurn(speaker="customer", text="Hi there", start_time=1.5, end_time=3.0),
        ]
        rendered = format_transcript_for_prompt(turns)
        for fragment in ("AGENT", "Hello", "CUSTOMER", "Hi there"):
            assert fragment in rendered

    def test_format_transcript_truncation(self):
        """Transcripts longer than max_chars are cut and flagged as truncated."""
        # Two 5000-char turns (10k chars total) against a 6000-char budget.
        turns = [
            SpeakerTurn(speaker="agent", text="A" * 5000, start_time=0.0, end_time=10.0),
            SpeakerTurn(speaker="customer", text="B" * 5000, start_time=10.0, end_time=20.0),
        ]
        rendered = format_transcript_for_prompt(turns, max_chars=6000)
        assert "truncated" in rendered
        # Allow headroom for labels/markers, but well under the raw 10k.
        assert len(rendered) < 8000
class TestAnalyzerValidation:
    """Tests for analyzer validation logic on RCALabel models.

    Covers the three model-level invariants: evidence spans are mandatory,
    confidence is bounded to [0, 1], and OTHER_EMERGENT drivers must carry
    a proposed label.
    """

    def test_evidence_required(self):
        """Test that evidence is required for RCA labels."""
        from src.models.call_analysis import EvidenceSpan, RCALabel

        # Valid: at least one evidence span backs the label.
        valid = RCALabel(
            driver_code="PRICE_TOO_HIGH",
            confidence=0.9,
            evidence_spans=[
                EvidenceSpan(
                    text="Es demasiado caro",
                    start_time=10.0,
                    end_time=12.0,
                )
            ],
        )
        assert valid.driver_code == "PRICE_TOO_HIGH"
        # Invalid: an empty evidence list must be rejected.
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="PRICE_TOO_HIGH",
                confidence=0.9,
                evidence_spans=[],  # Empty
            )

    def test_confidence_bounds(self):
        """Test confidence must be 0-1 (both bounds enforced)."""
        from src.models.call_analysis import EvidenceSpan, RCALabel

        evidence = [EvidenceSpan(text="test", start_time=0, end_time=1)]
        # Valid: mid-range confidence is accepted.
        valid = RCALabel(
            driver_code="TEST",
            confidence=0.5,
            evidence_spans=evidence,
        )
        assert valid.confidence == 0.5
        # Invalid: above the upper bound.
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="TEST",
                confidence=1.5,
                evidence_spans=evidence,
            )
        # Invalid: below the lower bound. The docstring promises a 0-1
        # range, so the negative side must be rejected as well.
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="TEST",
                confidence=-0.1,
                evidence_spans=evidence,
            )

    def test_emergent_requires_proposed_label(self):
        """Test OTHER_EMERGENT requires proposed_label."""
        from src.models.call_analysis import EvidenceSpan, RCALabel

        evidence = [EvidenceSpan(text="test", start_time=0, end_time=1)]
        # Valid: emergent driver paired with a proposed label.
        valid = RCALabel(
            driver_code="OTHER_EMERGENT",
            confidence=0.7,
            evidence_spans=evidence,
            proposed_label="NEW_PATTERN",
        )
        assert valid.proposed_label == "NEW_PATTERN"
        # Invalid: emergent driver without a proposed label must fail.
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="OTHER_EMERGENT",
                confidence=0.7,
                evidence_spans=evidence,
            )
@pytest.fixture
def config_dir(project_root):
    """Path to the project's ``config`` directory (relies on the
    ``project_root`` fixture from conftest)."""
    return project_root.joinpath("config")