"""Tests for the ``translate`` module: translate(), _call_openai() retry/backoff,
translate_html() and translate_feed()."""
from unittest.mock import MagicMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
from openai import RateLimitError
|
|
|
|
import translate
|
|
from translate import translate_feed, translate_html
|
|
|
|
|
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
|
|
|
def _openai_response(text: str) -> MagicMock:
    """Build a fake OpenAI chat-completion response.

    Only ``choices[0].message.content`` is set explicitly (to *text*); any
    other attribute is whatever MagicMock auto-creates on access.
    """
    choice = MagicMock()
    choice.message.content = text
    response = MagicMock()
    response.choices = [choice]
    return response
|
|
|
|
|
|
def _rate_limit_error() -> RateLimitError:
    """Build an ``openai.RateLimitError`` backed by a fake HTTP 429 response
    to the chat-completions endpoint."""
    request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
    return RateLimitError(
        "Rate limit exceeded",
        response=httpx.Response(429, request=request),
        body=None,
    )
|
|
|
|
|
|
# ── translate() ───────────────────────────────────────────────────────────────
|
|
|
|
def test_translate_empty_string_unchanged(db_conn):
    """An empty string comes back as-is and the API is never contacted."""
    with patch.object(translate, "_client") as fake_client:
        translated = translate.translate("", "sk", db_conn)

    assert translated == ""
    fake_client.chat.completions.create.assert_not_called()
|
|
|
|
|
|
def test_translate_whitespace_unchanged(db_conn):
    """Whitespace-only input comes back verbatim and the API is never contacted."""
    with patch.object(translate, "_client") as fake_client:
        translated = translate.translate(" ", "sk", db_conn)

    assert translated == " "
    fake_client.chat.completions.create.assert_not_called()
|
|
|
|
|
|
def test_translate_calls_openai_on_cache_miss(db_conn):
    """Text not yet cached is sent to OpenAI exactly once and its reply is returned."""
    with patch.object(translate, "_client") as fake_client:
        create = fake_client.chat.completions.create
        create.return_value = _openai_response("Ahoj")
        translated = translate.translate("Hello", "sk", db_conn)

    assert translated == "Ahoj"
    create.assert_called_once()
|
|
|
|
|
|
def test_translate_returns_cache_on_hit(db_conn):
    """Translating the same text twice only hits the API once; the repeat is
    served from the cache."""
    with patch.object(translate, "_client") as fake_client:
        create = fake_client.chat.completions.create
        create.return_value = _openai_response("Ahoj")
        translate.translate("Hello", "sk", db_conn)  # first call fills the cache
        second = translate.translate("Hello", "sk", db_conn)

    assert second == "Ahoj"
    assert create.call_count == 1
|
|
|
|
|
|
def test_translate_on_error_returns_original(db_conn):
    """If the API call raises, translate() falls back to the untranslated text."""
    with patch.object(translate, "_client") as fake_client:
        fake_client.chat.completions.create.side_effect = Exception("API down")
        outcome = translate.translate("Hello", "sk", db_conn)

    assert outcome == "Hello"
|
|
|
|
|
|
def test_translate_error_not_cached(db_conn):
    """A failed translation must not be cached, so a later call asks the API again."""
    with patch.object(translate, "_client") as fake_client:
        create = fake_client.chat.completions.create

        # First attempt: the API raises.
        create.side_effect = Exception("API down")
        translate.translate("Hello", "sk", db_conn)

        # Second attempt: the API succeeds and must actually be called.
        create.side_effect = None
        create.return_value = _openai_response("Ahoj")
        outcome = translate.translate("Hello", "sk", db_conn)

    assert outcome == "Ahoj"
    assert create.call_count == 2
|
|
|
|
|
|
# ── _call_openai() retry / backoff ────────────────────────────────────────────
|
|
|
|
def test_call_openai_retries_on_rate_limit():
    """_call_openai() retries after RateLimitError and returns the eventual success.

    No ``db_conn`` fixture is requested: _call_openai() takes only the text and
    target language, so setting up a database here would be wasted work.
    """
    with patch("translate.time.sleep") as mock_sleep, \
         patch("translate._client") as mock_client:
        mock_client.chat.completions.create.side_effect = [
            _rate_limit_error(),
            _rate_limit_error(),
            _openai_response("Preložený"),
        ]
        result = translate._call_openai("Hello", "sk")

    assert result == "Preložený"
    assert mock_client.chat.completions.create.call_count == 3
    # One back-off sleep per rate-limited attempt (two failures -> two sleeps).
    assert mock_sleep.call_count == 2
|
|
|
|
|
|
def test_call_openai_raises_after_max_retries():
    """When every attempt is rate-limited, _call_openai() stops after 5 tries
    and lets RateLimitError propagate."""
    with patch("translate.time.sleep"):
        with patch("translate._client") as fake_client:
            create = fake_client.chat.completions.create
            create.side_effect = _rate_limit_error()  # raised on every call
            with pytest.raises(RateLimitError):
                translate._call_openai("Hello", "sk")

    assert create.call_count == 5
|
|
|
|
|
|
def test_call_openai_backoff_increases():
    """The sleep between rate-limited attempts doubles: 1.0, 2.0, then 4.0 seconds."""
    failures = [_rate_limit_error() for _ in range(3)]
    with patch("translate.time.sleep") as fake_sleep:
        with patch("translate._client") as fake_client:
            fake_client.chat.completions.create.side_effect = [
                *failures,
                _openai_response("ok"),
            ]
            translate._call_openai("Hello", "sk")

    observed = [entry.args[0] for entry in fake_sleep.call_args_list]
    assert observed == [1.0, 2.0, 4.0]
|
|
|
|
|
|
# ── translate_html() ──────────────────────────────────────────────────────────
|
|
|
|
def test_translate_html_preserves_img_tags(db_conn):
    """<img> elements reappear verbatim in the output alongside the translated text."""
    source_html = '<p>Hello</p><img src="photo.jpg" alt="x">'
    with patch("translate._client") as fake_client:
        fake_client.chat.completions.create.return_value = _openai_response("<p>Ahoj</p>")
        translated = translate_html(source_html, "sk", db_conn)

    assert '<img src="photo.jpg" alt="x">' in translated
    assert "Ahoj" in translated
|
|
|
|
|
|
def test_translate_html_strips_img_before_api_call(db_conn):
    """<img> tags are removed from the HTML before it is sent to OpenAI."""
    with patch("translate._client") as mock_client:
        mock_client.chat.completions.create.return_value = _openai_response("<p>Ahoj</p>")
        translate_html('<p>Hello</p><img src="photo.jpg">', "sk", db_conn)

    create = mock_client.chat.completions.create
    # Fail clearly if no request was made, instead of with a TypeError from
    # subscripting a None ``call_args``.
    create.assert_called()
    # NOTE(review): assumes the HTML travels in messages[1] (e.g. a
    # [system, user] layout) — confirm against translate._call_openai.
    sent = create.call_args.kwargs["messages"][1]["content"]
    assert "<img" not in sent
|
|
|
|
|
|
# ── translate_feed() ──────────────────────────────────────────────────────────
|
|
|
|
# Minimal RSS 2.0 fixture: a channel <title> plus one <item> carrying a
# <title> and a <description>.
_RSS = b"""<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0">
  <channel>
    <title>Channel</title>
    <item>
      <title>Hello</title>
      <description>World</description>
    </item>
  </channel>
</rss>"""
|
|
|
|
# Minimal Atom fixture (default Atom namespace): a feed <title> plus one
# <entry> carrying a <title> and a <summary>.
_ATOM = b"""<?xml version="1.0" encoding="UTF-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <title>Feed</title>
  <entry>
    <title>Hello</title>
    <summary>World</summary>
  </entry>
</feed>"""
|
|
|
|
# RSS fixture whose only item has an empty <title></title> and a
# self-closing <description/>: nothing in it has text to translate.
_EMPTY_TAGS = b"""<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0">
  <channel>
    <item>
      <title></title>
      <description/>
    </item>
  </channel>
</rss>"""
|
|
|
|
|
|
def test_translate_feed_rss_translates_content(db_conn):
    """Text returned by the API appears in the bytes produced for an RSS feed."""
    with patch("translate._client") as fake_client:
        fake_client.chat.completions.create.return_value = _openai_response("Ahoj")
        feed_bytes = translate_feed(_RSS, "sk", db_conn)

    assert b"Ahoj" in feed_bytes
|
|
|
|
|
|
def test_translate_feed_rss_calls_openai_per_tag(db_conn):
    """Each distinct translatable item text in the RSS fixture costs one API call."""
    with patch("translate._client") as fake_client:
        create = fake_client.chat.completions.create
        create.return_value = _openai_response("X")
        translate_feed(_RSS, "sk", db_conn)

    # The item's <title> (plain text) and <description> (HTML) are two distinct
    # texts -> two calls. The channel-level <title>Channel</title> is presumably
    # left untranslated, since translating it would make a third call.
    assert create.call_count == 2
|
|
|
|
|
|
def test_translate_feed_invalid_xml_raises(db_conn):
    """Malformed input is rejected with a ValueError whose message matches 'Invalid XML'."""
    garbage = b"not xml <<<"
    with pytest.raises(ValueError, match="Invalid XML"):
        translate_feed(garbage, "sk", db_conn)
|
|
|
|
|
|
def test_translate_feed_empty_tags_no_crash(db_conn):
    """Empty item tags neither crash translate_feed() nor trigger any API call."""
    with patch("translate._client") as fake_client:
        create = fake_client.chat.completions.create
        create.return_value = _openai_response("X")
        output = translate_feed(_EMPTY_TAGS, "sk", db_conn)

    create.assert_not_called()
    assert b"<rss" in output
|
|
|
|
|
|
def test_translate_feed_atom_does_not_crash(db_conn):
    """An Atom feed is processed without error and the output still contains '<feed'."""
    with patch("translate._client") as fake_client:
        fake_client.chat.completions.create.return_value = _openai_response("X")
        output = translate_feed(_ATOM, "sk", db_conn)

    assert b"<feed" in output
|