"""
CXInsights - Inference Module Tests

Tests for LLM client, prompt manager, and analyzer.
Uses mocks for LLM calls to avoid API costs.
"""

import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from src.inference.client import LLMClient, LLMClientConfig, LLMResponse
from src.inference.prompt_manager import (
    PromptManager,
    PromptTemplate,
    format_events_for_prompt,
    format_transcript_for_prompt,
)
from src.models.call_analysis import Event, EventType
from src.transcription.models import SpeakerTurn, Transcript, TranscriptMetadata


def _fake_api_key_env():
    """Build a fresh ``patch.dict`` that injects a dummy OPENAI_API_KEY.

    A new patcher is created on every call so each test enters and exits
    its own context.
    """
    return patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})


class TestLLMResponse:
    """Checks on the LLMResponse value object."""

    def test_cost_estimate(self) -> None:
        """The cost estimate matches the per-token pricing formula."""
        response = LLMResponse(
            content="test",
            prompt_tokens=1000,
            completion_tokens=500,
            total_tokens=1500,
        )

        # GPT-4o-mini: $0.15/1M input, $0.60/1M output
        input_cost = (1000 / 1_000_000) * 0.15
        output_cost = (500 / 1_000_000) * 0.60
        expected = input_cost + output_cost

        # NOTE(review): the 1e-4 tolerance is roughly 22% of the expected
        # 4.5e-4 value, so this check is loose; consider tightening it once
        # the rounding behavior of cost_estimate_usd is confirmed.
        assert abs(response.cost_estimate_usd - expected) < 0.0001

    def test_success_flag(self) -> None:
        """The success flag is stored exactly as passed in."""
        ok = LLMResponse(content="test", success=True)
        failed = LLMResponse(content="", success=False, error="API error")

        assert ok.success is True
        assert failed.success is False

    def test_parsed_json(self) -> None:
        """A parsed JSON payload given at construction is kept on the response."""
        payload = {"key": "value"}
        response = LLMResponse(
            content='{"key": "value"}',
            parsed_json={"key": "value"},
        )

        assert response.parsed_json == payload


class TestLLMClientConfig:
    """Checks on LLMClientConfig defaults and overrides."""

    def test_default_config(self) -> None:
        """A config built with no arguments exposes the expected defaults."""
        cfg = LLMClientConfig()

        assert cfg.model == "gpt-4o-mini"
        assert cfg.temperature == 0.1
        assert cfg.max_tokens == 4000
        assert cfg.json_mode is True

    def test_custom_config(self) -> None:
        """Explicit keyword arguments override the defaults."""
        cfg = LLMClientConfig(
            model="gpt-4o",
            temperature=0.5,
            max_tokens=8000,
        )

        assert cfg.model == "gpt-4o"
        assert cfg.temperature == 0.5
        assert cfg.max_tokens == 8000


class TestLLMClient:
    """Checks on LLMClient construction, JSON parsing, and usage counters."""

    def test_requires_api_key(self) -> None:
        """Constructing a client with an empty environment raises ValueError."""
        cleared_env = patch.dict("os.environ", {}, clear=True)
        with cleared_env, pytest.raises(ValueError, match="API key required"):
            LLMClient()

    def test_parse_json_valid(self) -> None:
        """A plain JSON string is parsed into a dict."""
        with _fake_api_key_env():
            llm = LLMClient()
            parsed = llm._parse_json('{"key": "value"}')
            assert parsed == {"key": "value"}

    def test_parse_json_with_markdown(self) -> None:
        """JSON wrapped in a markdown code fence is parsed."""
        with _fake_api_key_env():
            llm = LLMClient()
            fenced = '```json\n{"key": "value"}\n```'
            parsed = llm._parse_json(fenced)
            assert parsed == {"key": "value"}

    def test_parse_json_extract_from_text(self) -> None:
        """A JSON object embedded in surrounding prose is extracted."""
        with _fake_api_key_env():
            llm = LLMClient()
            noisy = 'Here is the result: {"key": "value"} end.'
            parsed = llm._parse_json(noisy)
            assert parsed == {"key": "value"}

    def test_parse_json_invalid(self) -> None:
        """Text that contains no JSON yields None."""
        with _fake_api_key_env():
            llm = LLMClient()
            parsed = llm._parse_json("not json at all")
            assert parsed is None

    def test_usage_stats_tracking(self) -> None:
        """A freshly built client reports zeroed usage counters."""
        with _fake_api_key_env():
            llm = LLMClient()

            # Initially zero
            usage = llm.get_usage_stats()
            assert usage["total_calls"] == 0
            assert usage["total_tokens"] == 0

    def test_reset_usage_stats(self) -> None:
        """reset_usage_stats() brings the counters back to zero."""
        with _fake_api_key_env():
            llm = LLMClient()

            # Simulate prior activity by writing the private counters directly.
            llm._total_calls = 10
            llm._total_tokens = 5000

            llm.reset_usage_stats()

            usage = llm.get_usage_stats()
            assert usage["total_calls"] == 0
            assert usage["total_tokens"] == 0


class TestPromptTemplate:
    """Checks on PromptTemplate rendering and message building."""

    def test_render_basic(self) -> None:
        """Supplied variables are substituted into both prompts."""
        tpl = PromptTemplate(
            name="test",
            version="v1.0",
            system="You are analyzing call $call_id",
            user="Transcript: $transcript",
        )

        rendered_system, rendered_user = tpl.render(
            call_id="CALL001",
            transcript="Hello world",
        )

        assert "CALL001" in rendered_system
        assert "Hello world" in rendered_user

    def test_render_missing_var(self) -> None:
        """A placeholder with no supplied value survives rendering (safe_substitute)."""
        tpl = PromptTemplate(
            name="test",
            version="v1.0",
            system="Call $call_id in $queue",
            user="Text",
        )

        # Only the system prompt is inspected; the user prompt is discarded.
        rendered_system, _ = tpl.render(call_id="CALL001")

        # safe_substitute leaves $queue as-is
        assert "$queue" in rendered_system

    def test_to_messages(self) -> None:
        """to_messages() returns a system message followed by a user message."""
        tpl = PromptTemplate(
            name="test",
            version="v1.0",
            system="System message",
            user="User message",
        )

        msgs = tpl.to_messages()

        assert len(msgs) == 2
        assert msgs[0]["role"] == "system"
        assert msgs[1]["role"] == "user"


class TestPromptManager:
    """Checks on PromptManager loading, versioning, and caching."""

    def test_load_call_analysis_prompt(self, config_dir) -> None:
        """The v1.0 call_analysis prompt loads with non-empty prompts."""
        prompts = PromptManager(config_dir / "prompts")

        loaded = prompts.load("call_analysis", "v1.0")

        assert loaded.name == "call_analysis"
        assert loaded.version == "v1.0"
        assert len(loaded.system) > 0
        assert len(loaded.user) > 0

    def test_load_nonexistent_prompt(self, config_dir) -> None:
        """Loading an unknown prompt type raises FileNotFoundError."""
        prompts = PromptManager(config_dir / "prompts")

        with pytest.raises(FileNotFoundError):
            prompts.load("nonexistent", "v1.0")

    def test_get_active_version(self, config_dir) -> None:
        """The active call_analysis version is the one the test pins."""
        prompts = PromptManager(config_dir / "prompts")

        active = prompts.get_active_version("call_analysis")

        assert active == "v2.0"  # Updated to v2.0 with Blueprint alignment

    def test_list_prompt_types(self, config_dir) -> None:
        """call_analysis appears among the listed prompt types."""
        prompts = PromptManager(config_dir / "prompts")

        available = prompts.list_prompt_types()

        assert "call_analysis" in available

    def test_caching(self, config_dir) -> None:
        """Loading the same prompt twice returns the very same object."""
        prompts = PromptManager(config_dir / "prompts")

        first = prompts.load("call_analysis", "v1.0")
        second = prompts.load("call_analysis", "v1.0")

        assert first is second  # Same object


class TestFormatFunctions:
    """Checks on the prompt formatting helpers."""

    def test_format_events_empty(self) -> None:
        """An empty event list produces the 'No significant events' text."""
        formatted = format_events_for_prompt([])

        assert "No significant events" in formatted

    def test_format_events_with_events(self) -> None:
        """Event type names and start times show up in the formatted output."""
        hold = Event(
            event_type=EventType.HOLD_START,
            start_time=10.0,
        )
        silence = Event(
            event_type=EventType.SILENCE,
            start_time=30.0,
            duration_sec=8.0,
        )

        formatted = format_events_for_prompt([hold, silence])

        assert "HOLD_START" in formatted
        assert "10.0s" in formatted
        assert "SILENCE" in formatted

    def test_format_transcript_basic(self) -> None:
        """Upper-cased speaker labels and the turn text appear in the output."""
        agent_turn = SpeakerTurn(
            speaker="agent",
            text="Hello",
            start_time=0.0,
            end_time=1.0,
        )
        customer_turn = SpeakerTurn(
            speaker="customer",
            text="Hi there",
            start_time=1.5,
            end_time=3.0,
        )

        formatted = format_transcript_for_prompt([agent_turn, customer_turn])

        assert "AGENT" in formatted
        assert "Hello" in formatted
        assert "CUSTOMER" in formatted
        assert "Hi there" in formatted

    def test_format_transcript_truncation(self) -> None:
        """A transcript longer than max_chars is shortened and marked truncated."""
        agent_turn = SpeakerTurn(
            speaker="agent",
            text="A" * 5000,  # Long text
            start_time=0.0,
            end_time=10.0,
        )
        customer_turn = SpeakerTurn(
            speaker="customer",
            text="B" * 5000,  # Long text
            start_time=10.0,
            end_time=20.0,
        )

        formatted = format_transcript_for_prompt(
            [agent_turn, customer_turn],
            max_chars=6000,
        )

        assert "truncated" in formatted
        assert len(formatted) < 8000


class TestAnalyzerValidation:
    """Checks on the validation rules enforced by RCALabel."""

    def test_evidence_required(self) -> None:
        """An RCA label is accepted with evidence and rejected without any."""
        from src.models.call_analysis import EvidenceSpan, RCALabel

        # Valid: with evidence
        span = EvidenceSpan(
            text="Es demasiado caro",
            start_time=10.0,
            end_time=12.0,
        )
        label = RCALabel(
            driver_code="PRICE_TOO_HIGH",
            confidence=0.9,
            evidence_spans=[span],
        )
        assert label.driver_code == "PRICE_TOO_HIGH"

        # Invalid: without evidence
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="PRICE_TOO_HIGH",
                confidence=0.9,
                evidence_spans=[],  # Empty
            )

    def test_confidence_bounds(self) -> None:
        """A confidence of 0.5 is accepted and a confidence above 1 is rejected."""
        from src.models.call_analysis import EvidenceSpan, RCALabel

        spans = [EvidenceSpan(text="test", start_time=0, end_time=1)]

        # Valid
        label = RCALabel(
            driver_code="TEST",
            confidence=0.5,
            evidence_spans=spans,
        )
        assert label.confidence == 0.5

        # Invalid: > 1
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="TEST",
                confidence=1.5,
                evidence_spans=spans,
            )

    def test_emergent_requires_proposed_label(self) -> None:
        """OTHER_EMERGENT labels are rejected unless proposed_label is given."""
        from src.models.call_analysis import EvidenceSpan, RCALabel

        spans = [EvidenceSpan(text="test", start_time=0, end_time=1)]

        # Valid: with proposed_label
        label = RCALabel(
            driver_code="OTHER_EMERGENT",
            confidence=0.7,
            evidence_spans=spans,
            proposed_label="NEW_PATTERN",
        )
        assert label.proposed_label == "NEW_PATTERN"

        # Invalid: without proposed_label
        with pytest.raises(ValueError):
            RCALabel(
                driver_code="OTHER_EMERGENT",
                confidence=0.7,
                evidence_spans=spans,
            )


@pytest.fixture
def config_dir(project_root):
    """Provide the ``config`` directory located under ``project_root``."""
    return project_root / "config"