Skip to content
76 changes: 71 additions & 5 deletions s3file/forms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import base64
import html
import logging
import pathlib
import uuid
from html.parser import HTMLParser

from django.conf import settings
from django.templatetags.static import static
Expand All @@ -16,6 +18,71 @@
logger = logging.getLogger("s3file")


class InputToS3FileRewriter(HTMLParser):
    """Rewrite ``<input type="file">`` tags into ``<s3-file>`` custom elements.

    Feed markup with :meth:`feed` and collect the rewritten markup with
    :meth:`get_html`. File inputs are re-emitted as ``<s3-file ...>`` with
    every attribute except ``type``; all other markup is passed through.
    """

    def __init__(self):
        # convert_charrefs must be off: with the default (True) the parser
        # decodes entities in text nodes before calling handle_data(), so
        # "&lt;" would be written back as a raw "<" into output that is later
        # marked safe. With it off, references are routed to
        # handle_entityref()/handle_charref(), which re-emit them unchanged.
        super().__init__(convert_charrefs=False)
        self.output = []

    def _rewrite_tag(self, tag, attrs):
        """Append the start tag to the output, rewriting file inputs."""
        if tag != "input" or dict(attrs).get("type") != "file":
            # Not a file input: emit the tag exactly as it appeared.
            self.output.append(self.get_starttag_text())
            return
        self.output.append("<s3-file")
        # The parser hands over attribute values already unescaped, so they
        # are escaped again here. Attributes without a value (or with an
        # empty one) are written in their bare form, e.g. ``required``.
        self.output.extend(
            f' {name}="{html.escape(value, quote=True)}"' if value else f" {name}"
            for name, value in attrs
            if name != "type"
        )
        self.output.append(">")

    def handle_starttag(self, tag, attrs):
        self._rewrite_tag(tag, attrs)

    def handle_startendtag(self, tag, attrs):
        # Overridden so that ``<input ... />`` does not fall back to the base
        # implementation, which would also call handle_endtag().
        self._rewrite_tag(tag, attrs)

    def handle_endtag(self, tag):
        self.output.append(f"</{tag}>")

    def handle_data(self, data):
        self.output.append(data)

    def handle_comment(self, data):
        # Preserve HTML comments in the output
        self.output.append(f"<!--{data}-->")

    def handle_decl(self, decl):
        # Preserve declarations such as <!DOCTYPE ...> in the output
        self.output.append(f"<!{decl}>")

    def handle_pi(self, data):
        # Preserve processing instructions such as <?xml ...?> in the output
        self.output.append(f"<?{data}>")

    def handle_entityref(self, name):
        # Preserve HTML entities like &amp;, &lt;, &gt;
        self.output.append(f"&{name};")

    def handle_charref(self, name):
        # Preserve character references like &#39;, &#x27;
        self.output.append(f"&#{name};")

    def get_html(self):
        """Return everything emitted so far as a single string."""
        return "".join(self.output)


@html_safe
class Asset:
"""A generic asset that can be included in a template."""
Expand Down Expand Up @@ -99,11 +166,10 @@ def build_attrs(self, *args, **kwargs):

def render(self, name, value, attrs=None, renderer=None):
    """Render the widget as a custom element for Safari compatibility.

    The parent widget's markup is passed through ``InputToS3FileRewriter``,
    which turns the ``<input type="file">`` tag into ``<s3-file>``, and the
    result is returned marked as safe.
    """
    html_output = str(super().render(name, value, attrs=attrs, renderer=renderer))
    parser = InputToS3FileRewriter()
    parser.feed(html_output)
    # Flush anything the parser is still buffering (e.g. trailing text it was
    # holding back) so no part of the parent markup is dropped.
    parser.close()
    return mark_safe(parser.get_html())  # noqa: S308

def get_conditions(self, accept):
conditions = [
Expand Down
107 changes: 107 additions & 0 deletions tests/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,98 @@ def test_str(self, settings):
assert str(js) == '<script src="/static/path" type="module"></script>'


class TestInputToS3FileRewriter:
    """Unit tests for ``forms.InputToS3FileRewriter``."""

    @staticmethod
    def _rewrite(markup):
        """Feed *markup* through a fresh rewriter and return its output."""
        parser = forms.InputToS3FileRewriter()
        parser.feed(markup)
        return parser.get_html()

    def test_transforms_file_input(self):
        result = self._rewrite('<input type="file" name="test">')
        assert result == '<s3-file name="test">'

    def test_preserves_non_file_input(self):
        result = self._rewrite('<input type="text" name="test">')
        assert result == '<input type="text" name="test">'

    def test_handles_attribute_ordering(self):
        result = self._rewrite('<input name="test" type="file" class="foo">')
        assert result.startswith("<s3-file")
        assert 'name="test"' in result
        assert 'class="foo"' in result
        assert 'type="file"' not in result

    def test_handles_multiple_attributes(self):
        result = self._rewrite(
            '<input type="file" name="test" accept="image/*" required multiple>'
        )
        assert result.startswith("<s3-file")
        assert 'name="test"' in result
        assert 'accept="image/*"' in result
        assert "required" in result
        assert "multiple" in result

    def test_escapes_html_entities(self):
        result = self._rewrite(
            '<input type="file" name="test" data-value="test&value">'
        )
        assert 'data-value="test&amp;value"' in result

    def test_preserves_existing_html_entities(self):
        # Already-escaped entities must not be escaped a second time.
        result = self._rewrite(
            '<input type="file" name="test" data-value="test&amp;value">'
        )
        assert 'data-value="test&amp;value"' in result
        assert "&amp;amp;" not in result

    def test_preserves_character_references(self):
        result = self._rewrite(
            '<input type="file" name="test" data-value="test&#39;s">'
        )
        # Decimal and hexadecimal references both stand for the apostrophe,
        # so either spelling is accepted.
        assert (
            'data-value="test&#39;s"' in result
            or 'data-value="test&#x27;s"' in result
        )
        # The apostrophe must never appear unescaped in the output.
        assert "test's" not in result

    def test_handles_self_closing_tag(self):
        result = self._rewrite('<input type="file" name="test" />')
        assert result == '<s3-file name="test">'

    def test_preserves_non_file_self_closing_tag(self):
        result = self._rewrite('<input type="text" name="test" />')
        assert result == '<input type="text" name="test" />'

    def test_preserves_surrounding_elements(self):
        result = self._rewrite('<p><input type="file" name="test"></p>')
        assert result == '<p><s3-file name="test"></p>'

    def test_preserves_html_comments(self):
        result = self._rewrite('<!-- comment --><input type="file" name="test">')
        assert result == '<!-- comment --><s3-file name="test">'

    def test_preserves_declarations(self):
        result = self._rewrite('<!DOCTYPE html><input type="file" name="test">')
        assert result == '<!DOCTYPE html><s3-file name="test">'

    def test_preserves_processing_instructions(self):
        result = self._rewrite(
            '<?xml version="1.0"?><input type="file" name="test">'
        )
        assert result == '<?xml version="1.0"?><s3-file name="test">'


@contextmanager
def wait_for_page_load(driver, timeout=30):
old_page = driver.find_element(By.TAG_NAME, "html")
Expand Down Expand Up @@ -186,6 +278,21 @@ def test_render_wraps_in_s3_file_element(self, freeze_upload_folder):
# Check that the output is the s3-file custom element
assert html.startswith("<s3-file")

def test_render_preserves_attributes(self, freeze_upload_folder):
    """Attributes given to the widget appear on the rendered ``<s3-file>`` tag."""
    extra_attrs = {"class": "test-class", "accept": "image/*"}
    rendered = ClearableFileInput(attrs=extra_attrs).render(name="file", value=None)
    assert rendered.startswith("<s3-file")
    for fragment in ('name="file"', 'class="test-class"', 'accept="image/*"'):
        assert fragment in rendered
    # The original input type must not leak onto the custom element.
    assert 'type="file"' not in rendered

def test_render_excludes_type_attribute(self, freeze_upload_folder):
    """A default widget renders as ``<s3-file>`` without a ``type`` attribute."""
    rendered = ClearableFileInput().render(name="file", value=None)
    assert 'type="file"' not in rendered
    assert rendered.startswith("<s3-file")

@pytest.mark.selenium
def test_no_js_error(self, driver, live_server):
driver.get(live_server + self.create_url)
Expand Down
Loading