Skip to content

Commit

Permalink
Implement one-time-salt use and add comprehensive tests
Browse files Browse the repository at this point in the history
  • Loading branch information
snoopdave committed Oct 5, 2024
1 parent 78a2c78 commit e2dbf4d
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha
throws IOException, ServletException {

HttpServletRequest httpReq = (HttpServletRequest) request;
RollerSession rses = RollerSession.getRollerSession(httpReq);
String userId = rses != null && rses.getAuthenticatedUser() != null ? rses.getAuthenticatedUser().getId() : "";

SaltCache saltCache = SaltCache.getInstance();
String salt = RandomStringUtils.random(20, 0, 0, true, true, null, new SecureRandom());
saltCache.put(salt, userId);
httpReq.setAttribute("salt", salt);
RollerSession rollerSession = RollerSession.getRollerSession(httpReq);
if (rollerSession != null) {
String userId = rollerSession.getAuthenticatedUser() != null ? rollerSession.getAuthenticatedUser().getId() : "";
SaltCache saltCache = SaltCache.getInstance();
String salt = RandomStringUtils.random(20, 0, 0, true, true, null, new SecureRandom());
saltCache.put(salt, userId);
httpReq.setAttribute("salt", salt);
}

chain.doFilter(request, response);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,24 @@ public void doFilter(ServletRequest request, ServletResponse response,
HttpServletRequest httpReq = (HttpServletRequest) request;

if ("POST".equals(httpReq.getMethod()) && !isIgnoredURL(httpReq.getServletPath())) {
RollerSession rses = RollerSession.getRollerSession(httpReq);
String userId = rses != null && rses.getAuthenticatedUser() != null ? rses.getAuthenticatedUser().getId() : "";
RollerSession rollerSession = RollerSession.getRollerSession(httpReq);
if (rollerSession != null) {
String userId = rollerSession.getAuthenticatedUser() != null ? rollerSession.getAuthenticatedUser().getId() : "";

String salt = httpReq.getParameter("salt");
SaltCache saltCache = SaltCache.getInstance();
if (salt == null || !Objects.equals(saltCache.get(salt), userId)) {
String salt = httpReq.getParameter("salt");
SaltCache saltCache = SaltCache.getInstance();
if (salt == null || !Objects.equals(saltCache.get(salt), userId)) {
if (log.isDebugEnabled()) {
log.debug("Valid salt value not found on POST to URL : " + httpReq.getServletPath());
}
throw new ServletException("Security Violation");
}

// Remove salt from cache after successful validation
saltCache.remove(salt);
if (log.isDebugEnabled()) {
log.debug("Valid salt value not found on POST to URL : " + httpReq.getServletPath());
log.debug("Salt used and invalidated: " + salt);
}
throw new ServletException("Security Violation");
}
}

Expand All @@ -70,8 +78,6 @@ public void doFilter(ServletRequest request, ServletResponse response,

@Override
public void init(FilterConfig filterConfig) throws ServletException {

// Construct our list of ignored urls
String urls = WebloggerConfig.getProperty("salt.ignored.urls");
ignored = Set.of(StringUtils.stripAll(StringUtils.split(urls, ",")));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package org.apache.roller.weblogger.ui.core.filters;

import org.apache.roller.weblogger.pojos.User;
import org.apache.roller.weblogger.ui.core.RollerSession;
import org.apache.roller.weblogger.ui.rendering.util.cache.SaltCache;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;

import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import static org.mockito.Mockito.*;

public class LoadSaltFilterTest {

private LoadSaltFilter filter;

@Mock
private HttpServletRequest request;

@Mock
private HttpServletResponse response;

@Mock
private FilterChain chain;

@Mock
private RollerSession rollerSession;

@Mock
private SaltCache saltCache;

@BeforeEach
public void setUp() {
MockitoAnnotations.initMocks(this);
filter = new LoadSaltFilter();
}

@Test
public void testDoFilterGeneratesSalt() throws Exception {
try (MockedStatic<RollerSession> mockedRollerSession = mockStatic(RollerSession.class);
MockedStatic<SaltCache> mockedSaltCache = mockStatic(SaltCache.class)) {

mockedRollerSession.when(() -> RollerSession.getRollerSession(request)).thenReturn(rollerSession);
mockedSaltCache.when(SaltCache::getInstance).thenReturn(saltCache);

when(rollerSession.getAuthenticatedUser()).thenReturn(new TestUser("userId"));

filter.doFilter(request, response, chain);

verify(request).setAttribute(eq("salt"), anyString());
verify(saltCache).put(anyString(), eq("userId"));
verify(chain).doFilter(request, response);
}
}

@Test
public void testDoFilterWithNullRollerSession() throws Exception {
try (MockedStatic<RollerSession> mockedRollerSession = mockStatic(RollerSession.class);
MockedStatic<SaltCache> mockedSaltCache = mockStatic(SaltCache.class)) {

mockedRollerSession.when(() -> RollerSession.getRollerSession(request)).thenReturn(null);
mockedSaltCache.when(SaltCache::getInstance).thenReturn(saltCache);

filter.doFilter(request, response, chain);

verify(request, never()).setAttribute(eq("salt"), anyString());
verify(saltCache, never()).put(anyString(), anyString());
verify(chain).doFilter(request, response);
}
}

private static class TestUser extends User {
private final String id;

TestUser(String id) {
this.id = id;
}

public String getId() {
return id;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package org.apache.roller.weblogger.ui.core.filters;

import org.apache.roller.weblogger.pojos.User;
import org.apache.roller.weblogger.ui.core.RollerSession;
import org.apache.roller.weblogger.ui.rendering.util.cache.SaltCache;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.*;

public class ValidateSaltFilterTest {

private ValidateSaltFilter filter;

@Mock
private HttpServletRequest request;

@Mock
private HttpServletResponse response;

@Mock
private FilterChain chain;

@Mock
private RollerSession rollerSession;

@Mock
private SaltCache saltCache;

@BeforeEach
public void setUp() {
MockitoAnnotations.openMocks(this);
filter = new ValidateSaltFilter();
}

@Test
public void testDoFilterWithGetMethod() throws Exception {
when(request.getMethod()).thenReturn("GET");

filter.doFilter(request, response, chain);

verify(chain).doFilter(request, response);
}

@Test
public void testDoFilterWithPostMethodAndValidSalt() throws Exception {
try (MockedStatic<RollerSession> mockedRollerSession = mockStatic(RollerSession.class);
MockedStatic<SaltCache> mockedSaltCache = mockStatic(SaltCache.class)) {

mockedRollerSession.when(() -> RollerSession.getRollerSession(request)).thenReturn(rollerSession);
mockedSaltCache.when(SaltCache::getInstance).thenReturn(saltCache);

when(request.getMethod()).thenReturn("POST");
when(request.getServletPath()).thenReturn("/someurl");
when(request.getParameter("salt")).thenReturn("validSalt");
when(saltCache.get("validSalt")).thenReturn("userId");
when(rollerSession.getAuthenticatedUser()).thenReturn(new TestUser("userId"));

filter.doFilter(request, response, chain);

verify(chain).doFilter(request, response);
verify(saltCache).remove("validSalt");
}
}

@Test
public void testDoFilterWithPostMethodAndInvalidSalt() throws Exception {
try (MockedStatic<RollerSession> mockedRollerSession = mockStatic(RollerSession.class);
MockedStatic<SaltCache> mockedSaltCache = mockStatic(SaltCache.class)) {

mockedRollerSession.when(() -> RollerSession.getRollerSession(request)).thenReturn(rollerSession);
mockedSaltCache.when(SaltCache::getInstance).thenReturn(saltCache);

when(request.getMethod()).thenReturn("POST");
when(request.getServletPath()).thenReturn("/someurl");
when(request.getParameter("salt")).thenReturn("invalidSalt");
when(saltCache.get("invalidSalt")).thenReturn(null);

assertThrows(ServletException.class, () -> {
filter.doFilter(request, response, chain);
});
}
}

@Test
public void testDoFilterWithPostMethodAndMismatchedUserId() throws Exception {
try (MockedStatic<RollerSession> mockedRollerSession = mockStatic(RollerSession.class);
MockedStatic<SaltCache> mockedSaltCache = mockStatic(SaltCache.class)) {

mockedRollerSession.when(() -> RollerSession.getRollerSession(request)).thenReturn(rollerSession);
mockedSaltCache.when(SaltCache::getInstance).thenReturn(saltCache);

when(request.getMethod()).thenReturn("POST");
when(request.getServletPath()).thenReturn("/someurl");
when(request.getParameter("salt")).thenReturn("validSalt");
when(saltCache.get("validSalt")).thenReturn("differentUserId");
when(rollerSession.getAuthenticatedUser()).thenReturn(new TestUser("userId"));

assertThrows(ServletException.class, () -> {
filter.doFilter(request, response, chain);
});
}
}

@Test
public void testDoFilterWithPostMethodAndNullRollerSession() throws Exception {
try (MockedStatic<RollerSession> mockedRollerSession = mockStatic(RollerSession.class);
MockedStatic<SaltCache> mockedSaltCache = mockStatic(SaltCache.class)) {

mockedRollerSession.when(() -> RollerSession.getRollerSession(request)).thenReturn(null);
mockedSaltCache.when(SaltCache::getInstance).thenReturn(saltCache);

when(request.getMethod()).thenReturn("POST");
when(request.getServletPath()).thenReturn("/someurl");
when(request.getParameter("salt")).thenReturn("validSalt");
when(saltCache.get("validSalt")).thenReturn("");

filter.doFilter(request, response, chain);

verify(saltCache, never()).remove("validSalt");
}
}
private static class TestUser extends User {
private final String id;

TestUser(String id) {
this.id = id;
}

@Override
public String getId() {
return id;
}
}
}

0 comments on commit e2dbf4d

Please sign in to comment.