diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index aeba3b916..704fe0cf4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -45,7 +45,7 @@ jobs: with: distribution: temurin java-version: '24' - - uses: axel-op/googlejavaformat-action@fe78db8a90171b6a836449f8d0e982d5d71e5c5a #v3.6.0 + - uses: axel-op/googlejavaformat-action@c1134ebd196c4cbffb077f9476585b0be8b6afcd #v4 with: args: "--set-exit-if-changed" diff --git a/AGENTS.md b/AGENTS.md index 634aadaa9..f61b90eec 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -57,6 +57,14 @@ docker run -d --name secshep_test_db \ The docker-compose `db` service is **not suitable for quick test runs** — it requires `mvn -Pdocker validate` first to generate SQL init scripts, and those scripts run on first container startup to create all challenge schemas. +### Docker DB init fails with SQL syntax error + +If the MariaDB container exits immediately with `ERROR 1064 (42000) at line 181` referencing a `CREATE PROCEDURE` statement, this is a **stale Docker image cache** issue, not a bug in the SQL. + +The SQL source files have `DELIMITER` statements commented out (for compatibility with tools that don't support `DELIMITER`). A build script (`docker/scripts/convert-sql-scripts.sh`) uncomments them when Maven copies the files to `docker/mariadb/target/`. If Docker reuses a cached image from before the conversion ran, the procedures fail to parse. + +Fix: `docker compose build --no-cache db && docker compose up -d db` + ### Test credentials Tests read DB connection details from `.env` via dotenv. The key values: @@ -76,6 +84,35 @@ TEST_MYSQL_PASSWORD=CowSaysMoo mvn test -B Do not commit `.env` changes that break CI. +### First-time app setup + +After `docker compose up`, the app redirects to `https://localhost/setup.jsp` for initial database configuration. **Ask the user before performing setup** — they may prefer to configure it themselves via the browser. If they ask you to do it: + +1. Get the auth token: + ```bash + docker exec secshep_tomcat cat /usr/local/tomcat/conf/SecurityShepherd.auth + ``` +2. Submit the setup via curl (the TLS cert is self-signed, use `-k`): + ```bash + curl -k -s -X POST https://localhost/setup \ + -d "dbhost=secshep_mariadb" \ + -d "dbport=3306" \ + -d "dbuser=root" \ + -d "dbpass=CowSaysMoo" \ + -d "dboverride=override" \ + -d "dbauth=" \ + -d "mhost=secshep_mongo" \ + -d "mport=27017" + ``` + +The setup servlet parameter names (from `Setup.java`) are: +- `dbhost`, `dbport`, `dbuser`, `dbpass` — MySQL/MariaDB connection +- `dboverride` — set to `override` to reinitialize schemas +- `dbauth` — the auth token (NOT `authToken`) +- `mhost`, `mport` — MongoDB connection (required, even if not using mongo challenges) + +The hostname must be the **Docker container name** (e.g. `secshep_mariadb`), not `localhost`, since the Tomcat container connects over the Docker network. + ## Git workflow - Never commit directly to `master` or `dev` diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9636318c9..4dfc70f38 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -43,6 +43,18 @@ Install ZenHub for your browser and click the ZenHub tab that will appear in thi ## How do I setup my dev environment? [Like This](https://github.com/OWASP/SecurityShepherd/wiki/Create-a-Security-Shepherd-Dev-Environment) +## Development Workflow +See [docs/development-workflow.md](docs/development-workflow.md) for the full development cycle, including how to: +- Create your branch and make changes +- Build and test your changes in the runtime environment using Docker +- Run the automated test suite + +## Running Tests +See [docs/testing.md](docs/testing.md) for instructions on running unit and integration tests with Docker. + +## Database Configuration +See [docs/database-configuration.md](docs/database-configuration.md) for connection pooling configuration and database setup. + ## Is there a Definition of Done? *Work in Progess* - [ ] New Code has 'Good' JUnit Tests that cover it diff --git a/docs/database-configuration.md b/docs/database-configuration.md new file mode 100644 index 000000000..31f2e75a9 --- /dev/null +++ b/docs/database-configuration.md @@ -0,0 +1,285 @@ +# Database Configuration + +This guide covers the database connection configuration for Security Shepherd, including the connection pooling implementation. + +## Overview + +Security Shepherd uses two databases: + +- **MariaDB/MySQL** - Main application database (user accounts, progress, modules) +- **MongoDB** - Challenge-specific data storage + +## Why Connection Pooling? + +### The Problem Without Pooling + +Without connection pooling, every database operation requires: +1. Opening a new TCP connection +2. Performing authentication handshake +3. Executing the query +4. Closing the connection + +This is expensive - each connection can take 20-50ms to establish, which adds up quickly under load. + +### Pros of Connection Pooling + +| Benefit | Description | +|---------|-------------| +| **Performance** | Connections are reused, eliminating connection overhead per request | +| **Resource Efficiency** | Limits total connections, preventing database overload | +| **Connection Validation** | Pool validates connections before use, catching stale connections | +| **Configurable** | Tune pool size, timeouts, and behavior for your workload | +| **Monitoring** | Pool provides statistics on usage, active connections, and wait times | + +### Cons of Connection Pooling + +| Drawback | Description | +|----------|-------------| +| **Memory Overhead** | Idle connections consume memory on both app and database servers | +| **Configuration Complexity** | Wrong pool size can cause issues (too small = contention, too large = resource waste) | +| **Connection Leaks** | Code that doesn't close connections properly can exhaust the pool | +| **Stale Connections** | Long-idle connections may be terminated by firewalls or the database | +| **Debugging Complexity** | Connection issues can be harder to diagnose with pooling | + +### When Pooling Makes Sense + +Connection pooling is beneficial when: +- Your application handles many concurrent requests +- Database operations are frequent +- Connection establishment time is significant relative to query time + +For Security Shepherd, pooling improves performance during: +- Multiple users accessing challenges simultaneously +- Scoreboard updates and leaderboard queries +- User authentication and session management + +## Connection Pooling Implementation + +### MySQL/MariaDB - HikariCP + +Security Shepherd uses [HikariCP](https://github.com/brettwooldridge/HikariCP) for MySQL/MariaDB connection pooling. HikariCP is a high-performance JDBC connection pool. + +**Benefits:** +- Efficient connection reuse +- Reduced connection overhead +- Configurable pool size +- Connection validation and health checks + +### MongoDB - Singleton Pattern + +MongoDB connections use a singleton `MongoClient` instance. The MongoDB Java driver includes built-in connection pooling, so we maintain a single client instance that's reused across the application. + +## Configuration Files + +### MySQL/MariaDB Configuration + +File: `src/main/resources/database.properties` + +See example: `src/main/resources/database.properties.example` + +```properties +# Database connection URL (without schema) +databaseConnectionURL=jdbc:mysql://localhost:3306/ + +# JDBC driver class +DriverType=org.gjt.mm.mysql.Driver + +# Connection options +databaseOptions=useUnicode=true&character_set_server=utf8mb4 + +# Database schema name +databaseSchema=core + +# Credentials +databaseUsername=root +databasePassword=your_password + +# HikariCP Pool Settings (optional - defaults shown) +pool.maximumPoolSize=10 +pool.minimumIdle=2 +pool.connectionTimeout=30000 +pool.idleTimeout=600000 +pool.maxLifetime=1800000 +pool.poolName=SecurityShepherdPool +``` + +### Pool Configuration Options + +| Property | Default | Description | +|----------|---------|-------------| +| `pool.maximumPoolSize` | 10 | Maximum number of connections in the pool | +| `pool.minimumIdle` | 2 | Minimum number of idle connections to maintain | +| `pool.connectionTimeout` | 30000 | Maximum time (ms) to wait for a connection | +| `pool.idleTimeout` | 600000 | Maximum time (ms) a connection can be idle | +| `pool.maxLifetime` | 1800000 | Maximum lifetime (ms) of a connection | +| `pool.leakDetectionThreshold` | 60000 | Logs a warning if a connection is held longer than this (ms). Set to 0 to disable | +| `pool.poolName` | SecurityShepherdPool | Name for the pool (appears in logs) | + +### MongoDB Configuration + +File: `src/main/resources/mongo.properties` + +See example: `src/main/resources/mongo.properties.example` + +```properties +# MongoDB connection settings +connectionHost=localhost +connectionPort=27017 +databaseName=shepherdGames +connectTimeout=1000 +socketTimeout=0 +serverSelectionTimeout=30000 + +# Connection pool settings (optional) +connectionsPerHost=10 +minConnectionsPerHost=2 +``` + +## Lifecycle Management + +Connection pools are managed by `DatabaseLifecycleListener`, which: + +1. **On application startup**: Initializes the HikariCP connection pool +2. **On application shutdown**: Closes all database connections gracefully + +This is registered in `web.xml`: + +```xml + + listeners.DatabaseLifecycleListener + +``` + +## Monitoring + +### Pool Statistics + +The `ConnectionPool` class provides pool statistics: + +```java +String stats = ConnectionPool.getPoolStats(); +// Returns: Pool: SecurityShepherdPool | Total: 10 | Active: 2 | Idle: 8 | Waiting: 0 +``` + +### Logging + +Connection pool events are logged at various levels: + +- **INFO**: Pool initialization and shutdown +- **DEBUG**: Connection acquisition and pool statistics +- **WARN**: Connection validation failures +- **ERROR**: Pool initialization failures + +## Docker Configuration + +When running with Docker, database hosts are configured via environment variables in `.env`: + +```properties +# MariaDB +CONTAINER_MARIADB=secshep_mariadb +DB_PASS=your_password +DB_PORT=3306 + +# MongoDB +CONTAINER_MONGO=secshep_mongo +``` + +The Tomcat container connects to databases using container names (e.g., `secshep_mariadb`) as hostnames within the Docker network. + +### First-Time Setup + +After starting the stack with `docker compose up`, the app will redirect to `https://localhost/setup.jsp` for initial database configuration. + +> **Note:** The TLS certificate is self-signed. Your browser will show a security warning — accept it to proceed. + +#### Step 1: Get the authentication token + +The setup page requires a token from the server's filesystem to prevent unauthorized configuration: + +```bash +docker exec secshep_tomcat cat /usr/local/tomcat/conf/SecurityShepherd.auth +``` + +#### Step 2: Fill in the setup form + +Navigate to `https://localhost/setup.jsp` and fill in: + +| Field | Value | Notes | +|-------|-------|-------| +| **Hostname** | `secshep_mariadb` | The Docker container name — **not** `localhost` | +| **Port** | `3306` | | +| **DB Username** | `root` | | +| **DB Password** | `CowSaysMoo` | Must match `DB_PASS` in `.env` | +| **Override Databases** | checked | Initializes all challenge schemas on first setup | +| **MongoDB Host** | `secshep_mongo` | The Docker container name — **not** `localhost` | +| **MongoDB Port** | `27017` | | +| **Authentication token** | (paste from step 1) | | + +#### Step 3: Submit + +Click submit. On success you will see "Database Configuration Complete" and be redirected to the login page. + +#### Alternative: setup via curl + +```bash +AUTH=$(docker exec secshep_tomcat cat /usr/local/tomcat/conf/SecurityShepherd.auth) +curl -k -s -X POST https://localhost/setup \ + -d "dbhost=secshep_mariadb" \ + -d "dbport=3306" \ + -d "dbuser=root" \ + -d "dbpass=CowSaysMoo" \ + -d "dboverride=override" \ + -d "dbauth=$AUTH" \ + -d "mhost=secshep_mongo" \ + -d "mport=27017" +``` + +> **Important:** The hostnames must be Docker container names (e.g. `secshep_mariadb`, `secshep_mongo`), not `localhost`. The Tomcat container connects to the databases over the Docker network, where containers are addressed by name. + +## Troubleshooting + +### Pool Exhaustion + +If you see "Connection is not available" errors: + +1. Check for connection leaks (connections not being closed) +2. Increase `pool.maximumPoolSize` +3. Review `pool.connectionTimeout` setting + +### Slow Connection Acquisition + +If connections are slow to acquire: + +1. Check database server health +2. Review network latency between app and database +3. Consider increasing `pool.minimumIdle` + +### Connection Validation Failures + +If connections are failing validation: + +1. Check database server is running +2. Verify credentials in properties file +3. Check network connectivity + +### MariaDB Container Exits with SQL Syntax Error + +If the MariaDB container exits immediately on first startup with an error like: + +``` +ERROR 1064 (42000) at line 181: You have an error in your SQL syntax; +check the manual that corresponds to your MariaDB server version for the right syntax to use near '' at line 3 +``` + +This is a **stale Docker image cache** issue. The SQL source files in `src/main/resources/database/` have `DELIMITER` statements commented out for compatibility with tools that don't support `DELIMITER`. During the Maven build (`mvn -Pdocker`), a script (`docker/scripts/convert-sql-scripts.sh`) uncomments them in the copies under `docker/mariadb/target/`. If Docker reuses a cached image from before this conversion ran, the stored procedures fail to parse. + +**Fix:** + +```bash +mvn -Pdocker validate # ensure SQL scripts are converted +docker compose build --no-cache db # rebuild without cache +docker compose down -v # remove old volumes with failed init +docker compose up -d db # start fresh +``` + +Note: `docker compose down -v` is needed because MariaDB only runs init scripts on first startup with an empty data volume. If the previous attempt partially initialized, the scripts won't re-run without removing the volume. diff --git a/docs/development-workflow.md b/docs/development-workflow.md new file mode 100644 index 000000000..7446d6400 --- /dev/null +++ b/docs/development-workflow.md @@ -0,0 +1,142 @@ +# Development Workflow + +This guide explains the full development cycle for contributing to Security Shepherd. + +## 1. Create Your Branch + +Fork the repository or create a branch from the `dev` branch: + +```bash +git checkout dev +git pull origin dev +git checkout -b "dev#" +``` + +Branch naming convention: `dev#` (e.g., `dev#536`) + +## 2. Make Your Changes + +Edit the code in your branch. Key directories: + +- `src/main/java/` - Java source code +- `src/main/webapp/` - Web resources (JSP, CSS, JS) +- `src/test/java/` - Unit tests + +## 3. Build the WAR + +Build the application with Maven: + +```bash +mvn -Pdocker clean install -DskipTests +``` + +This generates: +- The WAR file in `target/` +- HTTPS certificates for Docker + +## 4. Test in Runtime Environment + +Start the full application stack using Docker Compose: + +```bash +# Build and start all containers (MariaDB, MongoDB, Tomcat) +docker-compose up --build + +# Or run in detached mode +docker-compose up -d --build +``` + +Access the application: +- **HTTP**: http://localhost +- **HTTPS**: https://localhost:8443 + +Default login credentials: +- Username: `admin` +- Password: `password` + +### Viewing Logs + +```bash +# View all logs +docker-compose logs -f + +# View only web container logs +docker-compose logs -f web +``` + +### Rebuilding After Code Changes + +After making additional changes: + +```bash +# Rebuild the WAR +mvn -Pdocker clean install -DskipTests + +# Rebuild and restart only the web container +docker-compose up -d --build web +``` + +### Stopping the Environment + +```bash +docker-compose down +``` + +To also remove volumes (database data): + +```bash +docker-compose down -v +``` + +## 5. Run Automated Tests + +See [testing.md](testing.md) for instructions on running the unit and integration test suite. + +## 6. Submit Your Pull Request + +When your changes are complete and tests pass: + +1. Push your branch to your fork/origin +2. Create a Pull Request targeting the `dev` branch +3. Ensure all CI checks pass + +See [CONTRIBUTING.md](../CONTRIBUTING.md) for code formatting and PR guidelines. + +## Environment Configuration + +The `.env` file in the project root contains environment variables for Docker. Key variables: + +| Variable | Description | +|----------|-------------| +| `DB_PASS` | MariaDB root password | +| `HTTP_PORT` | HTTP port (default: 80) | +| `HTTPS_PORT` | HTTPS port (default: 8443) | + +## Troubleshooting + +### Container Won't Start + +Check if ports are already in use: + +```bash +# Check if port 80 is in use +lsof -i :80 +``` + +### Database Connection Issues + +Ensure the database container is healthy: + +```bash +docker-compose ps +docker-compose logs db +``` + +### Changes Not Reflected + +Make sure you rebuilt both the WAR and the container: + +```bash +mvn -Pdocker clean install -DskipTests +docker-compose up -d --build web +``` diff --git a/docs/testing.md b/docs/testing.md new file mode 100644 index 000000000..560c7fce6 --- /dev/null +++ b/docs/testing.md @@ -0,0 +1,144 @@ +# Running Tests + +This guide explains how to run the automated test suite for Security Shepherd. + +## Prerequisites + +- **Docker** and **docker-compose** installed +- **Maven** 3.x +- **JDK 8** or higher +- `.env` file configured (copy from project or create one) + +## Environment Variables + +Tests use the `dotenv` library to load database credentials from the `.env` file. Required variables: + +| Variable | Description | Default | +|----------|-------------|---------| +| `TEST_MYSQL_HOST` | MySQL/MariaDB host | `127.0.0.1` | +| `TEST_MYSQL_PORT` | MySQL/MariaDB port | `3306` | +| `TEST_MYSQL_PASSWORD` | MySQL root password | Must match `DB_PASS` | +| `TEST_MONGO_HOST` | MongoDB host | `127.0.0.1` | +| `TEST_MONGO_PORT` | MongoDB port | `27017` | + +**Important**: `TEST_MYSQL_PASSWORD` must match the `DB_PASS` value used when the database container was created. + +## Running Tests with Docker + +### Step 1: Start Database Containers + +Start only the database containers (not the web application): + +```bash +docker-compose up -d db mongo +``` + +Wait 15-30 seconds for the databases to initialize fully. + +### Step 2: Verify Containers Are Running + +```bash +docker-compose ps +``` + +You should see `secshep_mariadb` and `secshep_mongo` with status "Up". + +### Step 3: Run the Tests + +```bash +mvn test +``` + +### Step 4: Stop Containers + +When finished: + +```bash +docker-compose down +``` + +## Running Specific Tests + +Run a single test class: + +```bash +mvn test -Dtest=GetterTest +``` + +Run multiple test classes: + +```bash +mvn test -Dtest=GetterTest,SetterTest +``` + +Run tests matching a pattern: + +```bash +mvn test -Dtest=*Pool* +``` + +## Understanding Skipped Tests + +Some tests will show as "skipped" rather than failed. This happens when: + +1. **Database is not running** - Tests that require database connectivity are skipped +2. **Credentials mismatch** - `TEST_MYSQL_PASSWORD` doesn't match `DB_PASS` +3. **Connection refused** - Database container isn't ready yet + +This is intentional behavior. It allows basic unit tests to run even without a full database setup, while integration tests are skipped gracefully. + +### Example Output + +``` +Tests run: 16, Failures: 0, Errors: 0, Skipped: 11 +``` + +This indicates 5 tests ran successfully and 11 were skipped (likely database-dependent tests). + +## Troubleshooting + +### "Access denied for user 'root'" + +The password in your `.env` file doesn't match what the database was created with. Options: + +1. Update `TEST_MYSQL_PASSWORD` to match `DB_PASS` +2. Or recreate the database volume: + +```bash +docker-compose down -v +docker-compose up -d db mongo +``` + +### "Connection refused" + +The database container isn't running or isn't ready: + +```bash +# Check container status +docker-compose ps + +# Check database logs +docker-compose logs db + +# Wait and retry +sleep 30 +mvn test +``` + +### Tests Still Failing After Fix + +Make sure Maven picks up the latest code: + +```bash +mvn clean test +``` + +## Connection Pool Tests + +The connection pool tests (`ConnectionPoolTest`, `DatabaseLifecycleListenerTest`) verify the HikariCP connection pooling implementation. These tests: + +- **Require database connectivity** for full coverage +- **Skip gracefully** when database is unavailable +- **Always run** state-management tests that don't need a database + +See [database-configuration.md](database-configuration.md) for connection pool configuration details. diff --git a/pom.xml b/pom.xml index 684334bed..f18742cc4 100644 --- a/pom.xml +++ b/pom.xml @@ -114,6 +114,13 @@ 3.3.2 + + + com.zaxxer + HikariCP + 4.0.3 + + org.owasp.encoder diff --git a/src/it/java/dbProcs/ConnectionPoolIT.java b/src/it/java/dbProcs/ConnectionPoolIT.java new file mode 100644 index 000000000..ccead9262 --- /dev/null +++ b/src/it/java/dbProcs/ConnectionPoolIT.java @@ -0,0 +1,511 @@ +package dbProcs; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import com.zaxxer.hikari.HikariDataSource; +import java.io.IOException; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import testUtils.TestProperties; + +public class ConnectionPoolIT { + + private static final Logger log = LogManager.getLogger(ConnectionPoolIT.class); + private static boolean databaseAvailable = false; + + @BeforeAll + public static void setup() throws IOException, SQLException { + TestProperties.setTestPropertiesFileDirectory(log); + TestProperties.createMysqlResource(); + + try { + ConnectionPool.initialize(); + Connection conn = ConnectionPool.getConnection(); + conn.close(); + databaseAvailable = true; + log.info("Database is available - running full test suite"); + } catch (Exception e) { + databaseAvailable = false; + log.warn("Database not available - skipping connection-dependent tests: " + e.getMessage()); + } finally { + ConnectionPool.reset(); + } + } + + @AfterAll + public static void cleanup() { + ConnectionPool.shutdown(); + } + + @BeforeEach + public void resetBeforeEachTest() throws IOException { + TestProperties.createMysqlResource(); + ConnectionPool.reset(); + } + + private void requireDatabase() { + assumeTrue(databaseAvailable, "Database not available"); + } + + @Test + public void testPoolInitializationState() { + ConnectionPool.reset(); + assertFalse(ConnectionPool.isInitialized(), "Pool should not be initialized before first use"); + } + + @Test + public void testPoolInitialization() { + requireDatabase(); + + ConnectionPool.reset(); + assertFalse(ConnectionPool.isInitialized(), "Pool should not be initialized before first use"); + + ConnectionPool.initialize(); + assertTrue(ConnectionPool.isInitialized(), "Pool should be initialized after initialize()"); + } + + @Test + public void testConnectionAcquisition() throws SQLException { + requireDatabase(); + + if (!ConnectionPool.isInitialized()) { + ConnectionPool.initialize(); + } + + Connection conn = null; + try { + conn = ConnectionPool.getConnection(); + assertNotNull(conn, "Should be able to get a connection from the pool"); + assertFalse(conn.isClosed(), "Connection should not be closed"); + } finally { + if (conn != null) { + conn.close(); + } + } + } + + @Test + public void testConnectionReturn() throws SQLException { + requireDatabase(); + + if (!ConnectionPool.isInitialized()) { + ConnectionPool.initialize(); + } + + Connection conn = ConnectionPool.getConnection(); + assertNotNull(conn, "Should get a connection"); + + conn.close(); + assertTrue(conn.isClosed(), "Connection should appear closed after close()"); + + Connection conn2 = ConnectionPool.getConnection(); + assertNotNull(conn2, "Should be able to get another connection"); + conn2.close(); + } + + @Test + public void testPoolShutdown() throws SQLException { + requireDatabase(); + + ConnectionPool.reset(); + ConnectionPool.initialize(); + assertTrue(ConnectionPool.isInitialized(), "Pool should be initialized"); + + Connection conn = ConnectionPool.getConnection(); + assertNotNull(conn); + conn.close(); + + ConnectionPool.shutdown(); + assertFalse(ConnectionPool.isInitialized(), "Pool should not be initialized after shutdown"); + } + + @Test + public void testConcurrentConnections() throws InterruptedException { + requireDatabase(); + + ConnectionPool.reset(); + ConnectionPool.initialize(); + + final int numThreads = 10; + final int operationsPerThread = 5; + final CountDownLatch startLatch = new CountDownLatch(1); + final CountDownLatch doneLatch = new CountDownLatch(numThreads); + final AtomicInteger successCount = new AtomicInteger(0); + final AtomicInteger errorCount = new AtomicInteger(0); + + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + + for (int i = 0; i < numThreads; i++) { + executor.submit( + () -> { + try { + startLatch.await(); + for (int j = 0; j < operationsPerThread; j++) { + Connection conn = null; + try { + conn = ConnectionPool.getConnection(); + if (conn != null && !conn.isClosed()) { + successCount.incrementAndGet(); + } + Thread.sleep(10); + } catch (SQLException e) { + errorCount.incrementAndGet(); + log.error("Error getting connection: " + e.getMessage()); + } finally { + if (conn != null) { + try { + conn.close(); + } catch (SQLException e) { + log.warn("Error closing connection: " + e.getMessage()); + } + } + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + doneLatch.countDown(); + } + }); + } + + startLatch.countDown(); + + boolean completed = doneLatch.await(60, TimeUnit.SECONDS); + executor.shutdown(); + + assertTrue(completed, "All threads should complete within timeout"); + assertTrue(successCount.get() > 0, "Should have successful connections"); + log.info("Concurrent test: {} successful, {} errors", successCount.get(), errorCount.get()); + } + + @Test + public void testPoolConfiguration() throws SQLException { + requireDatabase(); + + ConnectionPool.reset(); + ConnectionPool.initialize(); + + String stats = ConnectionPool.getPoolStats(); + assertNotNull(stats, "Pool stats should not be null"); + assertTrue(stats.contains("Pool:"), "Stats should contain pool name"); + log.info("Pool stats: " + stats); + } + + @Test + public void testConnectionValidation() throws SQLException { + requireDatabase(); + + if (!ConnectionPool.isInitialized()) { + ConnectionPool.initialize(); + } + + Connection conn = ConnectionPool.getConnection(); + try { + assertNotNull(conn, "Connection should not be null"); + assertTrue(conn.isValid(5), "Connection should be valid"); + } finally { + conn.close(); + } + } + + @Test + public void testLazyInitialization() throws SQLException { + requireDatabase(); + + ConnectionPool.reset(); + assertFalse(ConnectionPool.isInitialized(), "Pool should not be initialized"); + + Connection conn = ConnectionPool.getConnection(); + assertNotNull(conn, "Should get a connection"); + assertTrue( + ConnectionPool.isInitialized(), "Pool should be initialized after getting connection"); + conn.close(); + } + + @Test + public void testChallengeConnectionAcquiresValidConnection() throws SQLException { + requireDatabase(); + + ConnectionPool.reset(); + ConnectionPool.initialize(); + + Connection conn = + ConnectionPool.getChallengeConnection( + getTestConnectionUrl(), "core", getTestDbOptions(), "root", getTestPassword()); + try { + assertNotNull(conn, "Challenge connection should not be null"); + assertFalse(conn.isClosed(), "Challenge connection should not be closed"); + assertTrue(conn.isValid(5), "Challenge connection should be valid"); + } finally { + conn.close(); + } + } + + @Test + public void testChallengePoolReusedForSameCredentials() throws SQLException { + requireDatabase(); + + ConnectionPool.reset(); + ConnectionPool.initialize(); + + assertEquals(0, ConnectionPool.getChallengePoolCount(), "Should start with no challenge pools"); + + // First call creates a pool + Connection conn1 = + ConnectionPool.getChallengeConnection( + getTestConnectionUrl(), "core", getTestDbOptions(), "root", getTestPassword()); + conn1.close(); + assertEquals(1, ConnectionPool.getChallengePoolCount(), "Should have one challenge pool"); + + // Second call with same credentials should reuse the pool + Connection conn2 = + ConnectionPool.getChallengeConnection( + getTestConnectionUrl(), "core", getTestDbOptions(), "root", getTestPassword()); + conn2.close(); + assertEquals(1, ConnectionPool.getChallengePoolCount(), "Should still have one pool (reused)"); + } + + @Test + public void testChallengePoolSeparatePerCredentials() throws SQLException { + requireDatabase(); + + ConnectionPool.reset(); + ConnectionPool.initialize(); + + // Create a pool with one set of credentials + Connection conn1 = + ConnectionPool.getChallengeConnection( + getTestConnectionUrl(), "core", getTestDbOptions(), "root", getTestPassword()); + conn1.close(); + assertEquals(1, ConnectionPool.getChallengePoolCount(), "Should have one challenge pool"); + + // Create a pool with different schema (simulates different challenge) + Connection conn2 = + ConnectionPool.getChallengeConnection( + getTestConnectionUrl(), + "information_schema", + getTestDbOptions(), + "root", + getTestPassword()); + conn2.close(); + assertEquals( + 2, ConnectionPool.getChallengePoolCount(), "Should have two pools for different schemas"); + } + + @Test + public void testChallengePoolHasCorrectSizing() throws SQLException { + requireDatabase(); + + ConnectionPool.reset(); + ConnectionPool.initialize(); + + String url = getTestConnectionUrl(); + String schema = "core"; + String options = getTestDbOptions(); + String username = "root"; + String password = getTestPassword(); + + Connection conn = + ConnectionPool.getChallengeConnection(url, schema, options, username, password); + conn.close(); + + // Build the pool key the same way ConnectionPool does + String jdbcUrl = url + schema + "?" + options; + String poolKey = jdbcUrl + ":" + username; + + HikariDataSource ds = ConnectionPool.getChallengePool(poolKey); + assertNotNull(ds, "Challenge pool should exist for key"); + assertEquals(3, ds.getMaximumPoolSize(), "Challenge pool maxPoolSize should be 3"); + assertEquals(0, ds.getMinimumIdle(), "Challenge pool minIdle should be 0"); + assertEquals(120000, ds.getIdleTimeout(), "Challenge pool idleTimeout should be 2 minutes"); + } + + @Test + public void testChallengePoolMaxConnectionsEnforced() throws SQLException { + requireDatabase(); + + ConnectionPool.reset(); + ConnectionPool.initialize(); + + // Acquire all 3 max connections + List connections = new ArrayList<>(); + try { + for (int i = 0; i < 3; i++) { + Connection conn = + ConnectionPool.getChallengeConnection( + getTestConnectionUrl(), "core", getTestDbOptions(), "root", getTestPassword()); + assertNotNull(conn, "Should get connection " + (i + 1)); + connections.add(conn); + } + + // The 4th connection should block and eventually timeout since pool max is 3 + // We can't easily test the timeout without waiting 30s, so instead verify + // the pool reports the right active count + String url = getTestConnectionUrl(); + String jdbcUrl = url + "core" + "?" + getTestDbOptions(); + String poolKey = jdbcUrl + ":" + "root"; + HikariDataSource ds = ConnectionPool.getChallengePool(poolKey); + assertEquals( + 3, ds.getHikariPoolMXBean().getActiveConnections(), "Should have 3 active connections"); + } finally { + for (Connection conn : connections) { + conn.close(); + } + } + } + + @Test + public void testShutdownClosesAllChallengePools() throws SQLException { + requireDatabase(); + + ConnectionPool.reset(); + ConnectionPool.initialize(); + + // Create two challenge pools + Connection conn1 = + ConnectionPool.getChallengeConnection( + getTestConnectionUrl(), "core", getTestDbOptions(), "root", getTestPassword()); + conn1.close(); + Connection conn2 = + ConnectionPool.getChallengeConnection( + getTestConnectionUrl(), + "information_schema", + getTestDbOptions(), + "root", + getTestPassword()); + conn2.close(); + assertEquals(2, ConnectionPool.getChallengePoolCount(), "Should have two challenge pools"); + + ConnectionPool.shutdown(); + assertEquals( + 0, ConnectionPool.getChallengePoolCount(), "Shutdown should clear all challenge pools"); + } + + @Test + public void testConcurrentChallengeConnections() throws InterruptedException { + requireDatabase(); + + ConnectionPool.reset(); + ConnectionPool.initialize(); + + final int numThreads = 6; + final int operationsPerThread = 5; + final CountDownLatch startLatch = new CountDownLatch(1); + final CountDownLatch doneLatch = new CountDownLatch(numThreads); + final AtomicInteger successCount = new AtomicInteger(0); + final AtomicInteger errorCount = new AtomicInteger(0); + + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + + for (int i = 0; i < numThreads; i++) { + executor.submit( + () -> { + try { + startLatch.await(); + for (int j = 0; j < operationsPerThread; j++) { + Connection conn = null; + try { + conn = + ConnectionPool.getChallengeConnection( + getTestConnectionUrl(), + "core", + getTestDbOptions(), + "root", + getTestPassword()); + if (conn != null && !conn.isClosed()) { + successCount.incrementAndGet(); + } + Thread.sleep(10); + } catch (SQLException e) { + errorCount.incrementAndGet(); + log.error("Error getting challenge connection: " + e.getMessage()); + } finally { + if (conn != null) { + try { + conn.close(); + } catch (SQLException e) { + log.warn("Error closing challenge connection: " + e.getMessage()); + } + } + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + doneLatch.countDown(); + } + }); + } + + startLatch.countDown(); + + boolean completed = doneLatch.await(60, TimeUnit.SECONDS); + executor.shutdown(); + + assertTrue(completed, "All threads should complete within timeout"); + assertEquals( + numThreads * operationsPerThread, + successCount.get(), + "All challenge connection attempts should succeed"); + assertEquals(0, errorCount.get(), "Should have no errors"); + assertEquals( + 1, + ConnectionPool.getChallengePoolCount(), + "Should have exactly one challenge pool (all threads used same credentials)"); + } + + @Test + public void testPoolShutdownWithoutInit() { + ConnectionPool.reset(); + assertFalse(ConnectionPool.isInitialized(), "Pool should not be initialized"); + + ConnectionPool.shutdown(); + + assertFalse( + ConnectionPool.isInitialized(), "Pool should still not be initialized after shutdown"); + } + + @Test + public void testResetIsIdempotent() { + ConnectionPool.reset(); + ConnectionPool.reset(); + ConnectionPool.reset(); + + assertFalse( + ConnectionPool.isInitialized(), "Pool should not be initialized after multiple resets"); + } + + private static String getTestConnectionUrl() { + io.github.cdimascio.dotenv.Dotenv dotenv = io.github.cdimascio.dotenv.Dotenv.load(); + String host = dotenv.get("TEST_MYSQL_HOST"); + String port = dotenv.get("TEST_MYSQL_PORT"); + return "jdbc:mariadb://" + host + ":" + port + "/"; + } + + private static String getTestPassword() { + io.github.cdimascio.dotenv.Dotenv dotenv = io.github.cdimascio.dotenv.Dotenv.load(); + return dotenv.get("TEST_MYSQL_PASSWORD"); + } + + private static String getTestDbOptions() { + return "useUnicode=true&character_set_server=utf8mb4"; + } +} diff --git a/src/it/java/dbProcs/MongoDatabaseIT.java b/src/it/java/dbProcs/MongoDatabaseIT.java index 379ea6ebb..59ca10fac 100644 --- a/src/it/java/dbProcs/MongoDatabaseIT.java +++ b/src/it/java/dbProcs/MongoDatabaseIT.java @@ -2,15 +2,25 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; import com.mongodb.MongoClient; import com.mongodb.MongoCredential; import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; @@ -29,6 +39,16 @@ public static void initAll() throws IOException { TestProperties.createMongoResource(); } + @BeforeEach + public void setupEach() { + MongoDatabase.resetInstance(); + } + + @AfterEach + public void teardownEach() { + MongoDatabase.resetInstance(); + } + @Test @DisplayName("Should Return Type MongoCredentials") public void getMongoChallengeCredentials_ShouldReturnTypeMongoCredentials() throws IOException { @@ -77,12 +97,119 @@ public void getMongoDbConnection_ShouldReturnTypeMongoClient() { @Test @DisplayName("Must return type (Mongo) DB") @Disabled( - "Fongo 2.1.0 is incompatible with mongo-java-driver 3.12.14 (NPE in createOperationExecutor)") + "Fongo 2.1.0 is incompatible with mongo-java-driver 3.12.14 (NPE in" + + " createOperationExecutor)") public void getMongoDatabase_ShouldReturnTypeDB() {} @Test @DisplayName("Read properties file for db name") @Disabled( - "Fongo 2.1.0 is incompatible with mongo-java-driver 3.12.14 (NPE in createOperationExecutor)") + "Fongo 2.1.0 is incompatible with mongo-java-driver 3.12.14 (NPE in" + + " createOperationExecutor)") public void getMongoDatabase_ReadDbName() {} + + // ============= Singleton Pattern Tests ============= + + @Test + @DisplayName("Should return same MongoClient instance on multiple calls (singleton)") + public void getMongoDbConnection_ShouldReturnSingletonInstance() { + MongoClient client1 = MongoDatabase.getMongoDbConnection(null); + MongoClient client2 = MongoDatabase.getMongoDbConnection(null); + + assertNotNull(client1, "First MongoClient should not be null"); + assertNotNull(client2, "Second MongoClient should not be null"); + assertSame(client1, client2, "Multiple calls should return the same MongoClient instance"); + } + + @Test + @DisplayName("Should return same MongoClient for same credentials (singleton)") + public void getMongoDbConnection_WithCredentials_ShouldReturnSingletonPerCredential() { + MongoCredential credential1 = + MongoCredential.createScramSha1Credential("user1", "db1", "pass1".toCharArray()); + MongoCredential credential2 = + MongoCredential.createScramSha1Credential("user1", "db1", "pass1".toCharArray()); + + MongoClient client1 = MongoDatabase.getMongoDbConnection(null, credential1); + MongoClient client2 = MongoDatabase.getMongoDbConnection(null, credential2); + + assertNotNull(client1, "First MongoClient should not be null"); + assertNotNull(client2, "Second MongoClient should not be null"); + assertSame(client1, client2, "Same credentials should return the same MongoClient instance"); + } + + @Test + @DisplayName("Should track initialization state correctly") + public void isInitialized_ShouldTrackState() { + assertFalse(MongoDatabase.isInitialized(), "Should not be initialized before first use"); + + MongoDatabase.getMongoDbConnection(null); + + assertTrue(MongoDatabase.isInitialized(), "Should be initialized after first use"); + + MongoDatabase.resetInstance(); + assertFalse(MongoDatabase.isInitialized(), "Should not be initialized after reset"); + } + + @Test + @DisplayName("Should handle concurrent singleton initialization safely") + public void getMongoDbConnection_ShouldBeThreadSafe() throws InterruptedException { + final int numThreads = 10; + final CountDownLatch startLatch = new CountDownLatch(1); + final CountDownLatch doneLatch = new CountDownLatch(numThreads); + final AtomicReference firstClient = new AtomicReference<>(); + final java.util.concurrent.atomic.AtomicBoolean allSame = + new java.util.concurrent.atomic.AtomicBoolean(true); + + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + + for (int i = 0; i < numThreads; i++) { + executor.submit( + () -> { + try { + startLatch.await(); + MongoClient client = MongoDatabase.getMongoDbConnection(null); + + if (!firstClient.compareAndSet(null, client)) { + if (firstClient.get() != client) { + allSame.set(false); + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } finally { + doneLatch.countDown(); + } + }); + } + + startLatch.countDown(); + + boolean completed = doneLatch.await(30, TimeUnit.SECONDS); + executor.shutdown(); + + assertTrue(completed, "All threads should complete within timeout"); + assertTrue(allSame.get(), "All threads should get the same MongoClient instance"); + } + + @Test + @DisplayName("Shutdown should close all connections") + public void shutdown_ShouldCloseConnections() { + MongoDatabase.getMongoDbConnection(null); + assertTrue(MongoDatabase.isInitialized(), "Should be initialized"); + + MongoDatabase.shutdown(); + + assertFalse(MongoDatabase.isInitialized(), "Should not be initialized after shutdown"); + } + + @Test + @DisplayName("Reset should be equivalent to shutdown") + public void resetInstance_ShouldBeEquivalentToShutdown() { + MongoDatabase.getMongoDbConnection(null); + assertTrue(MongoDatabase.isInitialized(), "Should be initialized"); + + MongoDatabase.resetInstance(); + + assertFalse(MongoDatabase.isInitialized(), "Should not be initialized after reset"); + } } diff --git a/src/main/java/dbProcs/ConnectionPool.java b/src/main/java/dbProcs/ConnectionPool.java new file mode 100644 index 000000000..e92d663aa --- /dev/null +++ b/src/main/java/dbProcs/ConnectionPool.java @@ -0,0 +1,445 @@ +package dbProcs; + +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * Connection pool manager using HikariCP for efficient database connection management. + * + *

This class provides connection pooling for MySQL/MariaDB databases, significantly improving + * performance by reusing connections instead of creating new ones for each request. + * + *

This file is part of the Security Shepherd Project. + * + *

The Security Shepherd project is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free Software Foundation, either + * version 3 of the License, or (at your option) any later version. + * + * @author Paul + */ +public class ConnectionPool { + + private static final Logger log = LogManager.getLogger(ConnectionPool.class); + + // Singleton instance for the core database pool + private static volatile HikariDataSource coreDataSource; + + // Lock object for thread-safe initialization + private static final Object lock = new Object(); + + // Challenge-specific connection pools (keyed by schema name) + private static final ConcurrentHashMap challengePools = + new ConcurrentHashMap<>(); + + // Default pool configuration values (used by core pool) + private static final int DEFAULT_MAX_POOL_SIZE = 10; + private static final int DEFAULT_MIN_IDLE = 2; + private static final long DEFAULT_CONNECTION_TIMEOUT = 30000; // 30 seconds + private static final long DEFAULT_IDLE_TIMEOUT = 600000; // 10 minutes + private static final long DEFAULT_MAX_LIFETIME = 1800000; // 30 minutes + private static final long DEFAULT_LEAK_DETECTION_THRESHOLD = 60000; // 60 seconds + + // Challenge pool configuration values (smaller footprint per schema) + private static final int CHALLENGE_MAX_POOL_SIZE = 3; + private static final int CHALLENGE_MIN_IDLE = 0; + private static final long CHALLENGE_IDLE_TIMEOUT = 120000; // 2 minutes + + // Flag to track if pool has been initialized + private static volatile boolean initialized = false; + + /** Private constructor to prevent instantiation */ + private ConnectionPool() {} + + /** + * Initialize the connection pool. This method is idempotent - calling it multiple times has no + * effect after the first successful initialization. + */ + public static void initialize() { + if (!initialized) { + synchronized (lock) { + if (!initialized) { + try { + coreDataSource = createCoreDataSource(); + initialized = true; + log.info("Connection pool initialized successfully"); + } catch (Exception e) { + log.error("Failed to initialize connection pool: " + e.getMessage(), e); + throw new RuntimeException("Failed to initialize connection pool", e); + } + } + } + } + } + + /** + * Creates the HikariDataSource for the core database. + * + * @return Configured HikariDataSource + */ + private static HikariDataSource createCoreDataSource() { + Properties prop = loadDatabaseProperties(); + + String connectionURL = prop.getProperty("databaseConnectionURL"); + String databaseSchema = prop.getProperty("databaseSchema"); + String dbOptions = prop.getProperty("databaseOptions"); + String username = prop.getProperty("databaseUsername"); + String password = prop.getProperty("databasePassword"); + + // Build the full JDBC URL + String jdbcUrl = connectionURL + databaseSchema; + if (dbOptions != null && !dbOptions.isEmpty()) { + jdbcUrl += "?" + dbOptions; + } + + return createDataSource(jdbcUrl, username, password, prop, "CorePool"); + } + + /** + * Creates a HikariDataSource with the given configuration. + * + * @param jdbcUrl The JDBC URL + * @param username Database username + * @param password Database password + * @param prop Properties containing pool configuration + * @param poolName Name for the pool (for logging/monitoring) + * @return Configured HikariDataSource + */ + private static HikariDataSource createDataSource( + String jdbcUrl, String username, String password, Properties prop, String poolName) { + + HikariConfig config = new HikariConfig(); + + config.setJdbcUrl(jdbcUrl); + config.setUsername(username); + config.setPassword(password); + config.setPoolName(poolName); + + String driverClassName = prop.getProperty("DriverType"); + if (driverClassName == null || driverClassName.isEmpty()) { + if (jdbcUrl.startsWith("jdbc:mariadb:")) { + driverClassName = "org.mariadb.jdbc.Driver"; + } else if (jdbcUrl.startsWith("jdbc:mysql:")) { + driverClassName = "com.mysql.cj.jdbc.Driver"; + } else { + throw new IllegalArgumentException("Unsupported JDBC URL: " + jdbcUrl); + } + } + config.setDriverClassName(driverClassName); + + // Pool size configuration + config.setMaximumPoolSize(getIntProperty(prop, "pool.maximumPoolSize", DEFAULT_MAX_POOL_SIZE)); + config.setMinimumIdle(getIntProperty(prop, "pool.minimumIdle", DEFAULT_MIN_IDLE)); + + // Timeout configuration + config.setConnectionTimeout( + getLongProperty(prop, "pool.connectionTimeout", DEFAULT_CONNECTION_TIMEOUT)); + config.setIdleTimeout(getLongProperty(prop, "pool.idleTimeout", DEFAULT_IDLE_TIMEOUT)); + config.setMaxLifetime(getLongProperty(prop, "pool.maxLifetime", DEFAULT_MAX_LIFETIME)); + + // Leak detection - logs a warning if a connection is held longer than this threshold + config.setLeakDetectionThreshold( + getLongProperty(prop, "pool.leakDetectionThreshold", DEFAULT_LEAK_DETECTION_THRESHOLD)); + + // Connection validation + config.setConnectionTestQuery("SELECT 1"); + + // Performance optimizations + config.addDataSourceProperty("cachePrepStmts", "true"); + config.addDataSourceProperty("prepStmtCacheSize", "250"); + config.addDataSourceProperty("prepStmtCacheSqlLimit", "2048"); + config.addDataSourceProperty("useServerPrepStmts", "true"); + + log.debug( + "Creating HikariCP pool '{}' with maxPoolSize={}, minIdle={}", + poolName, + config.getMaximumPoolSize(), + config.getMinimumIdle()); + + return new HikariDataSource(config); + } + + /** + * Creates a HikariDataSource with custom pool size settings. Used for challenge pools which need + * a smaller resource footprint than the core pool. + * + * @param jdbcUrl The JDBC URL + * @param username Database username + * @param password Database password + * @param prop Properties containing pool configuration + * @param poolName Name for the pool (for logging/monitoring) + * @param maxPoolSize Maximum number of connections in the pool + * @param minIdle Minimum number of idle connections maintained + * @param idleTimeout Idle timeout in milliseconds before a connection is retired + * @return Configured HikariDataSource + */ + private static HikariDataSource createDataSource( + String jdbcUrl, + String username, + String password, + Properties prop, + String poolName, + int maxPoolSize, + int minIdle, + long idleTimeout) { + + HikariConfig config = new HikariConfig(); + + config.setJdbcUrl(jdbcUrl); + config.setUsername(username); + config.setPassword(password); + config.setPoolName(poolName); + + String driverClassName = prop.getProperty("DriverType"); + if (driverClassName == null || driverClassName.isEmpty()) { + if (jdbcUrl.startsWith("jdbc:mariadb:")) { + driverClassName = "org.mariadb.jdbc.Driver"; + } else if (jdbcUrl.startsWith("jdbc:mysql:")) { + driverClassName = "com.mysql.cj.jdbc.Driver"; + } else { + throw new IllegalArgumentException("Unsupported JDBC URL: " + jdbcUrl); + } + } + config.setDriverClassName(driverClassName); + + // Pool size configuration (using provided values instead of defaults) + config.setMaximumPoolSize(maxPoolSize); + config.setMinimumIdle(minIdle); + + // Timeout configuration + config.setConnectionTimeout( + getLongProperty(prop, "pool.connectionTimeout", DEFAULT_CONNECTION_TIMEOUT)); + config.setIdleTimeout(idleTimeout); + config.setMaxLifetime(getLongProperty(prop, "pool.maxLifetime", DEFAULT_MAX_LIFETIME)); + + // Leak detection - logs a warning if a connection is held longer than this threshold + config.setLeakDetectionThreshold( + getLongProperty(prop, "pool.leakDetectionThreshold", DEFAULT_LEAK_DETECTION_THRESHOLD)); + + // Connection validation + config.setConnectionTestQuery("SELECT 1"); + + // Performance optimizations + config.addDataSourceProperty("cachePrepStmts", "true"); + config.addDataSourceProperty("prepStmtCacheSize", "250"); + config.addDataSourceProperty("prepStmtCacheSqlLimit", "2048"); + config.addDataSourceProperty("useServerPrepStmts", "true"); + + log.debug( + "Creating HikariCP pool '{}' with maxPoolSize={}, minIdle={}, idleTimeout={}", + poolName, + config.getMaximumPoolSize(), + config.getMinimumIdle(), + config.getIdleTimeout()); + + return new HikariDataSource(config); + } + + /** + * Loads database properties from the configuration file. + * + * @return Properties object containing database configuration + */ + private static Properties loadDatabaseProperties() { + Properties prop = new Properties(); + String mysqlProps = Constants.MYSQL_DB_PROP; + + try (InputStream input = new FileInputStream(mysqlProps)) { + prop.load(input); + } catch (IOException e) { + log.error("Could not load database properties file: " + e.getMessage(), e); + throw new RuntimeException("Could not load database properties", e); + } + + return prop; + } + + /** + * Gets a connection from the core database pool. + * + * @return A connection from the pool + * @throws SQLException if a connection cannot be obtained + */ + public static Connection getConnection() throws SQLException { + if (!initialized) { + try { + initialize(); + } catch (RuntimeException e) { + throw new SQLException("Connection pool not available", e); + } + } + return coreDataSource.getConnection(); + } + + /** + * Gets a connection for a specific challenge schema. + * + * @param connectionURL The base connection URL + * @param challengeConnectionURL The challenge-specific part of the URL + * @param dbOptions Database options + * @param username Database username + * @param password Database password + * @return A connection from the challenge-specific pool + * @throws SQLException if a connection cannot be obtained + */ + public static Connection getChallengeConnection( + String connectionURL, + String challengeConnectionURL, + String dbOptions, + String username, + String password) + throws SQLException { + + // Build the full JDBC URL + final String jdbcUrl; + if (dbOptions != null && !dbOptions.isEmpty()) { + jdbcUrl = connectionURL + challengeConnectionURL + "?" + dbOptions; + } else { + jdbcUrl = connectionURL + challengeConnectionURL; + } + + // Use the full URL and username as the pool key + String poolKey = jdbcUrl + ":" + username; + + HikariDataSource dataSource = + challengePools.computeIfAbsent( + poolKey, + key -> { + Properties prop = loadDatabaseProperties(); + return createDataSource( + jdbcUrl, + username, + password, + prop, + "ChallengePool-" + username, + CHALLENGE_MAX_POOL_SIZE, + CHALLENGE_MIN_IDLE, + CHALLENGE_IDLE_TIMEOUT); + }); + + return dataSource.getConnection(); + } + + /** + * Shuts down all connection pools. This should be called when the application is shutting down. + */ + public static void shutdown() { + synchronized (lock) { + log.info("Shutting down connection pools..."); + + // Close core pool + if (coreDataSource != null && !coreDataSource.isClosed()) { + coreDataSource.close(); + log.debug("Core connection pool closed"); + } + + // Close all challenge pools + for (HikariDataSource ds : challengePools.values()) { + if (!ds.isClosed()) { + ds.close(); + } + } + challengePools.clear(); + + coreDataSource = null; + initialized = false; + + log.info("All connection pools shut down successfully"); + } + } + + /** + * Checks if the connection pool has been initialized. + * + * @return true if initialized, false otherwise + */ + public static boolean isInitialized() { + return initialized; + } + + /** + * Gets pool statistics for monitoring (useful for debugging). + * + * @return String containing pool statistics + */ + public static String getPoolStats() { + if (coreDataSource == null) { + return "Pool not initialized"; + } + + return String.format( + "Pool: %s, Active: %d, Idle: %d, Total: %d, Waiting: %d", + coreDataSource.getPoolName(), + coreDataSource.getHikariPoolMXBean().getActiveConnections(), + coreDataSource.getHikariPoolMXBean().getIdleConnections(), + coreDataSource.getHikariPoolMXBean().getTotalConnections(), + coreDataSource.getHikariPoolMXBean().getThreadsAwaitingConnection()); + } + + /** Helper method to get an integer property with a default value. */ + private static int getIntProperty(Properties prop, String key, int defaultValue) { + String value = prop.getProperty(key); + if (value != null) { + try { + return Integer.parseInt(value); + } catch (NumberFormatException e) { + log.warn( + "Invalid integer value for property '{}': {}, using default: {}", + key, + value, + defaultValue); + } + } + return defaultValue; + } + + /** Helper method to get a long property with a default value. */ + private static long getLongProperty(Properties prop, String key, long defaultValue) { + String value = prop.getProperty(key); + if (value != null) { + try { + return Long.parseLong(value); + } catch (NumberFormatException e) { + log.warn( + "Invalid long value for property '{}': {}, using default: {}", + key, + value, + defaultValue); + } + } + return defaultValue; + } + + /** Resets the pool state. This is primarily for testing purposes. */ + public static void reset() { + shutdown(); + } + + /** + * Returns the number of challenge pools currently active. Intended for testing and monitoring. + * + * @return the number of challenge pools + */ + public static int getChallengePoolCount() { + return challengePools.size(); + } + + /** + * Returns the HikariDataSource for a challenge pool by key, or null if not found. Intended for + * testing to verify pool configuration. + * + * @param poolKey the pool key (jdbcUrl:username) + * @return the HikariDataSource, or null + */ + static HikariDataSource getChallengePool(String poolKey) { + return challengePools.get(poolKey); + } +} diff --git a/src/main/java/dbProcs/Database.java b/src/main/java/dbProcs/Database.java index 662eec024..d69e60bb6 100644 --- a/src/main/java/dbProcs/Database.java +++ b/src/main/java/dbProcs/Database.java @@ -5,15 +5,14 @@ import java.io.IOException; import java.io.InputStream; import java.sql.Connection; -import java.sql.DriverManager; import java.sql.SQLException; import java.util.Properties; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; /** - * Used to create database connections using the FileInputProperties.readfile method to gather - * property information Initiated by Getter.java, Setter.java
+ * Used to create database connections using connection pooling via HikariCP. Connections are + * obtained from the pool and returned when closed. Initiated by Getter.java, Setter.java
*
* This file is part of the Security Shepherd Project. * @@ -29,16 +28,23 @@ * Shepherd project. If not, see . * * @author Mark + * @author Paul (connection pooling) */ public class Database { private static final Logger log = LogManager.getLogger(Database.class); /** - * This method is used by the application to open an connection to the database + * This method is used by the application to get a connection for challenge schemas. Connections + * are obtained from a pool specific to the challenge. * - * @param conn The connection to close - * @throws SQLException + * @param driverType The JDBC driver type (kept for API compatibility, but not used with pooling) + * @param connectionURL The base connection URL + * @param dbOptions Database connection options + * @param dbUsername Database username + * @param dbPassword Database password + * @return A pooled connection + * @throws SQLException if a connection cannot be obtained */ public static Connection getConnection( String driverType, @@ -48,33 +54,29 @@ public static Connection getConnection( String dbPassword) throws SQLException { - try { - Class.forName(driverType).newInstance(); - } catch (InstantiationException | IllegalAccessException | ClassNotFoundException e) { - throw new RuntimeException(e); - } - - if (dbOptions.length() > 0) { - connectionURL += "?" + dbOptions; - } - - Connection conn = DriverManager.getConnection(connectionURL, dbUsername, dbPassword); - - return conn; + // Extract the schema portion from the connectionURL for challenge connections + // The connectionURL at this point already includes the schema + return ConnectionPool.getChallengeConnection( + "", // Base URL is already included in connectionURL + connectionURL, + dbOptions, + dbUsername, + dbPassword); } /** - * This method is used by the application to close an open connection to a database server + * This method is used by the application to close/return a connection to the pool. With + * connection pooling, this returns the connection to the pool for reuse. * - * @param conn The connection to close + * @param conn The connection to return to the pool */ public static void closeConnection(Connection conn) { - - // log.debug("Closing database connection"); - try { - conn.close(); - } catch (SQLException e) { - throw new RuntimeException(e); + if (conn != null) { + try { + conn.close(); // With HikariCP, this returns the connection to the pool + } catch (SQLException e) { + log.warn("Error returning connection to pool: " + e.getMessage()); + } } } @@ -181,68 +183,19 @@ public static Connection getChallengeConnection(String ApplicationRoot, String p } public static Connection getCoreConnection() throws SQLException, IOException { - Connection conn = getCoreConnection(""); - - return conn; + return getCoreConnection(""); } /** - * @param ApplicationRoot @return Connection to core schema with admin privileges @throws - * FileNotFoundException @throws SQLException @throws - * @throws IOException Returns connection to core schema in database - * @throws FileNotFoundException - * @throws SQLException - * @throws RuntimeException + * Gets a connection to the core database schema from the connection pool. + * + * @param ApplicationRoot The running context of the application (kept for API compatibility) + * @return Connection to core schema from the pool + * @throws SQLException if a connection cannot be obtained */ public static Connection getCoreConnection(String ApplicationRoot) throws SQLException { - Connection conn = null; - Properties prop = new Properties(); - - // Pull Driver and DB URL out of database.properties - - String mysql_props = Constants.MYSQL_DB_PROP; - - try (InputStream mysql_input = new FileInputStream(mysql_props)) { - - prop.load(mysql_input); - - } catch (IOException e) { - log.error("Could not load properties file: " + e.toString()); - throw new RuntimeException(e); - } - - String errorBase = "Missing property :"; - - String connectionURL = prop.getProperty("databaseConnectionURL"); - if (connectionURL == null) { - throw new RuntimeException(errorBase + "connectionURL"); - } - String databaseSchema = prop.getProperty("databaseSchema"); - if (databaseSchema == null) { - throw new RuntimeException(errorBase + "databaseSchema"); - } - String dbOptions = prop.getProperty("databaseOptions"); - if (dbOptions == null) { - throw new RuntimeException(errorBase + "databaseOptions"); - } - String driverType = prop.getProperty("DriverType"); - if (driverType == null) { - throw new RuntimeException(errorBase + "DriverType"); - } - String username = prop.getProperty("databaseUsername"); - if (username == null) { - throw new RuntimeException(errorBase + "databaseUsername"); - } - String password = prop.getProperty("databasePassword"); - if (password == null) { - throw new RuntimeException(errorBase + "databasePassword"); - } - - connectionURL += databaseSchema; - - conn = getConnection(driverType, connectionURL, dbOptions, username, password); - - return conn; + // Use the connection pool for core connections + return ConnectionPool.getConnection(); } public static Connection getDatabaseConnection(String ApplicationRoot) diff --git a/src/main/java/dbProcs/MongoDatabase.java b/src/main/java/dbProcs/MongoDatabase.java index a8d7c2965..7e1c6adb9 100644 --- a/src/main/java/dbProcs/MongoDatabase.java +++ b/src/main/java/dbProcs/MongoDatabase.java @@ -13,20 +13,23 @@ import com.mongodb.MongoTimeoutException; import com.mongodb.ServerAddress; import java.io.File; +import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Paths; -import java.util.Collections; +import java.util.Arrays; import java.util.Objects; import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.io.FileUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; /** - * Used to create MongoDb connections
+ * Used to create MongoDb connections using a singleton pattern for connection reuse. MongoClient + * already includes internal connection pooling, so we only need one instance.
*
* This file is part of the Security Shepherd Project. * @@ -47,14 +50,38 @@ public class MongoDatabase { private static final Logger log = LogManager.getLogger(MongoDatabase.class); + // Singleton instance for the base MongoClient (no credentials) + private static volatile MongoClient baseMongoClient; + + // Map of credential-based MongoClients (keyed by credential source/database) + private static final ConcurrentHashMap credentialClients = + new ConcurrentHashMap<>(); + + // Lock object for thread-safe initialization + private static final Object lock = new Object(); + + // Default pool configuration + private static final int DEFAULT_CONNECTIONS_PER_HOST = 10; + private static final int DEFAULT_MIN_CONNECTIONS_PER_HOST = 2; + /** - * Method to close a MongoDb connection + * Method to close a MongoDb connection. Note: With singleton pattern, individual connections + * should NOT be closed by callers. This method is kept for API compatibility but does nothing for + * singleton clients. Use shutdown() to close all connections when the application is shutting + * down. * - * @param conn The connection to close + * @param conn The connection (ignored for singleton clients) + * @deprecated Use shutdown() instead when shutting down the application */ + @Deprecated public static void closeConnection(MongoClient conn) { - - conn.close(); + // Do not close singleton connections - they are managed by the pool + // Only close if it's not one of our managed singletons + if (conn != null && conn != baseMongoClient && !credentialClients.containsValue(conn)) { + conn.close(); + } else { + log.debug("Ignoring close request for singleton MongoClient - use shutdown() instead"); + } } /** @@ -149,24 +176,32 @@ public static String getMongoChallengeCollName(String ApplicationRoot, String pa } /** - * Method to get a MongoDb Connection + * Method to get a singleton MongoDb Connection. MongoClient has internal connection pooling, so + * the same instance is reused. * - * @return A MongoDb Connection @throws IOException @throws + * @param ApplicationRoot The running context of the application (kept for API compatibility) + * @return A singleton MongoDb Connection */ public static MongoClient getMongoDbConnection(String ApplicationRoot) { - Properties prop = new Properties(); - - // Mongo DB URL from mongo.properties - String mongo_props = Constants.MONGO_DB_PROP; - - try (InputStream mongo_input = Files.newInputStream(Paths.get(mongo_props))) { - - prop.load(mongo_input); - - } catch (IOException e) { - log.error("Could not load properties file: " + e.toString()); - throw new RuntimeException(e); + if (baseMongoClient == null) { + synchronized (lock) { + if (baseMongoClient == null) { + baseMongoClient = createMongoClient(null); + log.info("Created singleton MongoClient instance"); + } + } } + return baseMongoClient; + } + + /** + * Creates a MongoClient with the configuration from properties file. + * + * @param credential Optional credential for authenticated connections + * @return A new MongoClient instance + */ + private static MongoClient createMongoClient(MongoCredential credential) { + Properties prop = loadMongoProperties(); String errorBase = "Missing property :"; @@ -186,23 +221,39 @@ public static MongoClient getMongoDbConnection(String ApplicationRoot) { if (socketTimeout == null) { throw new RuntimeException(errorBase + "socketTimeout"); } - String serverSelectionTimeout = prop.getProperty("serverSelectionTimeout"); if (serverSelectionTimeout == null) { throw new RuntimeException(errorBase + "serverSelectionTimeout"); } + // Configure connection pool options MongoClientOptions.Builder optionsBuilder = MongoClientOptions.builder(); optionsBuilder.connectTimeout(Integer.parseInt(connectTimeout)); optionsBuilder.socketTimeout(Integer.parseInt(socketTimeout)); optionsBuilder.serverSelectionTimeout(Integer.parseInt(serverSelectionTimeout)); - MongoClientOptions mongoOptions = optionsBuilder.build(); - try (MongoClient mongoClient = - new MongoClient( - new ServerAddress(connectionHost, Integer.parseInt(connectionPort)), mongoOptions)) { + // Connection pool settings + int connectionsPerHost = + getIntProperty(prop, "pool.connectionsPerHost", DEFAULT_CONNECTIONS_PER_HOST); + int minConnectionsPerHost = + getIntProperty(prop, "pool.minConnectionsPerHost", DEFAULT_MIN_CONNECTIONS_PER_HOST); + optionsBuilder.connectionsPerHost(connectionsPerHost); + optionsBuilder.minConnectionsPerHost(minConnectionsPerHost); - log.debug("Mongo Client: " + mongoClient); + MongoClientOptions mongoOptions = optionsBuilder.build(); + ServerAddress serverAddress = + new ServerAddress(connectionHost, Integer.parseInt(connectionPort)); + + try { + MongoClient mongoClient; + if (credential != null) { + mongoClient = new MongoClient(serverAddress, Arrays.asList(credential), mongoOptions); + log.debug("Created MongoClient with credentials for: " + credential.getSource()); + } else { + mongoClient = new MongoClient(serverAddress, mongoOptions); + log.debug("Created MongoClient without credentials"); + } + log.debug("Connection Host: " + connectionHost + ", Port: " + connectionPort); return mongoClient; } catch (NumberFormatException e) { @@ -211,7 +262,6 @@ public static MongoClient getMongoDbConnection(String ApplicationRoot) { } catch (MongoSocketOpenException e) { log.fatal("Mongo Doesn't seem to be running: " + e); - e.printStackTrace(); throw new RuntimeException(e); } catch (MongoSocketException e) { @@ -225,91 +275,65 @@ public static MongoClient getMongoDbConnection(String ApplicationRoot) { } /** - * Method to get a MongoDb Connection + * Loads MongoDB properties from the configuration file. * - * @param credential to connect to the MongoDB - * @return A MongoDb Connection + * @return Properties object containing MongoDB configuration */ - public static MongoClient getMongoDbConnection( - String ApplicationRoot, MongoCredential credential) { - + private static Properties loadMongoProperties() { Properties prop = new Properties(); - MongoClient mongoClient = null; - - // Mongo DB URL from mongo.properties String mongo_props = Constants.MONGO_DB_PROP; - try (InputStream mongo_input = Files.newInputStream(Paths.get(mongo_props))) { - + try (InputStream mongo_input = new FileInputStream(mongo_props)) { prop.load(mongo_input); - } catch (IOException e) { log.error("Could not load properties file: " + e.toString()); throw new RuntimeException(e); } - String errorBase = "Missing property :"; - - String connectionHost = prop.getProperty("connectionHost"); - if (connectionHost == null) { - throw new RuntimeException(errorBase + "connectionHost"); - } - String connectionPort = prop.getProperty("connectionPort"); - if (connectionPort == null) { - throw new RuntimeException(errorBase + "connectionPort"); - } - String connectTimeout = prop.getProperty("connectTimeout"); - if (connectTimeout == null) { - throw new RuntimeException(errorBase + "connectTimeout"); - } - String socketTimeout = prop.getProperty("socketTimeout"); - if (socketTimeout == null) { - throw new RuntimeException(errorBase + "socketTimeout"); - } + return prop; + } - String serverSelectionTimeout = prop.getProperty("serverSelectionTimeout"); - if (serverSelectionTimeout == null) { - throw new RuntimeException(errorBase + "serverSelectionTimeout"); + /** Helper method to get an integer property with a default value. */ + private static int getIntProperty(Properties prop, String key, int defaultValue) { + String value = prop.getProperty(key); + if (value != null) { + try { + return Integer.parseInt(value); + } catch (NumberFormatException e) { + log.warn( + "Invalid integer value for property '{}': {}, using default: {}", + key, + value, + defaultValue); + } } + return defaultValue; + } - MongoClientOptions.Builder optionsBuilder = MongoClientOptions.builder(); - optionsBuilder.connectTimeout(Integer.parseInt(connectTimeout)); - optionsBuilder.socketTimeout(Integer.parseInt(socketTimeout)); - optionsBuilder.serverSelectionTimeout(Integer.parseInt(serverSelectionTimeout)); - MongoClientOptions mongoOptions = optionsBuilder.build(); - - try { - mongoClient = - new MongoClient( - new ServerAddress(connectionHost, Integer.parseInt(connectionPort)), - credential, - mongoOptions); - - log.debug("Connection Host: " + connectionHost); - log.debug("Connection Port: " + Integer.parseInt(connectionPort)); - log.debug("Connection Creds: " + Collections.singletonList(credential)); - } catch (NumberFormatException e) { - log.fatal("The port in the properties file is not a number: " + e); - throw new RuntimeException(e); - - } catch (MongoSocketException | MongoTimeoutException e) { - log.fatal("Unable to get Mongodb connection (Is it on?): " + e); - throw new RuntimeException(e); - - } catch (MongoException e) { - log.fatal("Something went wrong with Mongo: " + e); - e.printStackTrace(); - throw new RuntimeException(e); + /** + * Method to get a singleton MongoDb Connection with credentials. Each unique credential gets its + * own MongoClient instance (with internal pooling). + * + * @param ApplicationRoot The running context of the application (kept for API compatibility) + * @param credential The credential to connect to MongoDB + * @return A singleton MongoDb Connection for the given credential + */ + public static MongoClient getMongoDbConnection( + String ApplicationRoot, MongoCredential credential) { - } catch (Exception e) { - log.fatal("Something went wrong: " + e); - e.printStackTrace(); - throw new RuntimeException(e); + if (credential == null) { + return getMongoDbConnection(ApplicationRoot); } - log.debug("Mongo Client: " + mongoClient); + // Use the credential source (database name) as the key + String credentialKey = credential.getSource() + ":" + credential.getUserName(); - return mongoClient; + return credentialClients.computeIfAbsent( + credentialKey, + key -> { + log.debug("Creating new MongoClient for credential: " + credentialKey); + return createMongoClient(credential); + }); } /** @@ -321,19 +345,7 @@ public static MongoClient getMongoDbConnection( public static DB getMongoDatabase(MongoClient mongoClient) { DB mongoDb = null; - Properties prop = new Properties(); - - // Mongo DB URL from mongo.properties - String mongo_props = Constants.MONGO_DB_PROP; - - try (InputStream mongo_input = Files.newInputStream(Paths.get(mongo_props))) { - - prop.load(mongo_input); - - } catch (IOException e) { - log.error("Could not load properties file: " + e.toString()); - throw new RuntimeException(e); - } + Properties prop = loadMongoProperties(); String dbname = prop.getProperty("databaseName"); if (dbname == null) { @@ -373,4 +385,43 @@ public static void executeMongoScript(File file, MongoClient mongoClient) throws log.debug("Mongo Result: " + result); } + + /** + * Shuts down all MongoDB connections. This should be called when the application is shutting + * down. + */ + public static void shutdown() { + synchronized (lock) { + log.info("Shutting down MongoDB connections..."); + + // Close base client + if (baseMongoClient != null) { + baseMongoClient.close(); + baseMongoClient = null; + log.debug("Base MongoClient closed"); + } + + // Close all credential-based clients + for (MongoClient client : credentialClients.values()) { + client.close(); + } + credentialClients.clear(); + + log.info("All MongoDB connections shut down successfully"); + } + } + + /** Resets the singleton instances. This is primarily for testing purposes. */ + public static void resetInstance() { + shutdown(); + } + + /** + * Checks if the base MongoClient has been initialized. + * + * @return true if initialized, false otherwise + */ + public static boolean isInitialized() { + return baseMongoClient != null; + } } diff --git a/src/main/java/listeners/DatabaseLifecycleListener.java b/src/main/java/listeners/DatabaseLifecycleListener.java new file mode 100644 index 000000000..fe9929f17 --- /dev/null +++ b/src/main/java/listeners/DatabaseLifecycleListener.java @@ -0,0 +1,84 @@ +package listeners; + +import dbProcs.ConnectionPool; +import dbProcs.MongoDatabase; +import javax.servlet.ServletContextEvent; +import javax.servlet.ServletContextListener; +import javax.servlet.annotation.WebListener; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * Servlet context listener for managing database connection pool lifecycle. Initializes connection + * pools when the application starts and shuts them down when the application stops. + * + *

This file is part of the Security Shepherd Project. + * + *

The Security Shepherd project is free software: you can redistribute it and/or modify it under + * the terms of the GNU General Public License as published by the Free Software Foundation, either + * version 3 of the License, or (at your option) any later version. + * + * @author Paul + */ +@WebListener +public class DatabaseLifecycleListener implements ServletContextListener { + + private static final Logger log = LogManager.getLogger(DatabaseLifecycleListener.class); + + /** + * Called when the servlet context is initialized (application startup). Initializes the database + * connection pools. + * + * @param sce The servlet context event + */ + @Override + public void contextInitialized(ServletContextEvent sce) { + log.info("Application starting - initializing database connection pools..."); + + try { + // Initialize MySQL/MariaDB connection pool + ConnectionPool.initialize(); + log.info("MySQL/MariaDB connection pool initialized successfully"); + + // Note: MongoDB connections are lazy-initialized on first use + // since MongoClient has its own internal connection pool + log.info("MongoDB connections will be initialized on first use"); + + } catch (RuntimeException e) { + // Catch RuntimeException (including configuration errors) to prevent + // application startup from failing completely + log.error("Failed to initialize database connection pools: " + e.getMessage(), e); + log.warn( + "Application will continue without database connectivity. " + + "Some features may not work until database is properly configured."); + } catch (Exception e) { + log.error("Failed to initialize database connection pools: " + e.getMessage(), e); + // Don't prevent application startup, but log the error + // Some modules may still work without database connectivity + } + } + + /** + * Called when the servlet context is destroyed (application shutdown). Shuts down all database + * connection pools to release resources. + * + * @param sce The servlet context event + */ + @Override + public void contextDestroyed(ServletContextEvent sce) { + log.info("Application shutting down - closing database connection pools..."); + + try { + // Shutdown MySQL/MariaDB connection pool + ConnectionPool.shutdown(); + log.info("MySQL/MariaDB connection pool shut down successfully"); + + // Shutdown MongoDB connections + MongoDatabase.shutdown(); + log.info("MongoDB connections shut down successfully"); + + } catch (Exception e) { + log.error("Error during database connection pool shutdown: " + e.getMessage(), e); + } + } +} diff --git a/src/main/java/servlets/Setup.java b/src/main/java/servlets/Setup.java index 76a305986..c56481e14 100644 --- a/src/main/java/servlets/Setup.java +++ b/src/main/java/servlets/Setup.java @@ -79,14 +79,14 @@ public void doPost(HttpServletRequest request, HttpServletResponse response) if (hasDBFile) { // Db auth file exists, try to load from it - if (!dbHost.isEmpty() || !dbPort.isEmpty()) { - // One of db host and db port are missing, we can't handle this situation! + if (dbHost.isEmpty() != dbPort.isEmpty()) { + // Only one of db host and db port provided, we need both or neither htmlOutput += "If you override db host and db port, both must be entered!"; validateInput = false; connectionURL = ""; - } else if (dbHost.isEmpty() || dbPort.isEmpty()) { - // Both db host and db port are missing, good, load from props file instead + } else if (dbHost.isEmpty() && dbPort.isEmpty()) { + // Both db host and db port are missing, load from props file instead connectionURL = mysql_props.getProperty("databaseConnectionURL"); String databaseSchema = mysql_props.getProperty("databaseSchema"); @@ -353,6 +353,19 @@ public void doPost(HttpServletRequest request, HttpServletResponse response) out.close(); } + /** + * Validates that db host and port are either both provided or both empty. Returns null if valid, + * or an error message if invalid. + */ + static String validateHostPort(String dbHost, String dbPort) { + if (dbHost == null) dbHost = ""; + if (dbPort == null) dbPort = ""; + if (dbHost.isEmpty() != dbPort.isEmpty()) { + return "If you override db host and db port, both must be entered!"; + } + return null; + } + public static boolean isInstalled() { boolean isInstalled = false; diff --git a/src/main/resources/database.properties.example b/src/main/resources/database.properties.example new file mode 100644 index 000000000..1af699365 --- /dev/null +++ b/src/main/resources/database.properties.example @@ -0,0 +1,39 @@ +# Security Shepherd Database Configuration +# This file should be placed at: ${catalina.base}/conf/database.properties + +# Database Connection Settings +databaseConnectionURL=jdbc:mysql://localhost:3306/ +databaseSchema=core +DriverType=org.gjt.mm.mysql.Driver +databaseUsername=root +databasePassword=yourpassword +databaseOptions=useUnicode=true&character_set_server=utf8mb4 + +# Connection Pool Settings (HikariCP) +# These settings control the database connection pool behavior. +# Adjust based on your expected concurrent user load. + +# Maximum number of connections in the pool +# Default: 10 +pool.maximumPoolSize=10 + +# Minimum number of idle connections to maintain +# Default: 2 +pool.minimumIdle=2 + +# Maximum time (ms) to wait for a connection from the pool +# Default: 30000 (30 seconds) +pool.connectionTimeout=30000 + +# Maximum time (ms) a connection can sit idle in the pool +# Default: 600000 (10 minutes) +pool.idleTimeout=600000 + +# Maximum lifetime (ms) of a connection in the pool +# Default: 1800000 (30 minutes) +pool.maxLifetime=1800000 + +# Leak detection threshold (ms) - logs a warning if a connection is held longer than this +# Helps identify code that forgets to close connections +# Default: 60000 (60 seconds), set to 0 to disable +pool.leakDetectionThreshold=60000 diff --git a/src/main/resources/mongo.properties.example b/src/main/resources/mongo.properties.example new file mode 100644 index 000000000..89685389b --- /dev/null +++ b/src/main/resources/mongo.properties.example @@ -0,0 +1,23 @@ +# Security Shepherd MongoDB Configuration +# This file should be placed at: ${catalina.base}/conf/mongo.properties + +# MongoDB Connection Settings +connectionHost=localhost +connectionPort=27017 +databaseName=shepherdGames + +# Connection Timeouts (in milliseconds) +connectTimeout=10000 +socketTimeout=0 +serverSelectionTimeout=30000 + +# Connection Pool Settings +# These settings control the MongoDB connection pool behavior. + +# Maximum number of connections per host +# Default: 10 +pool.connectionsPerHost=10 + +# Minimum number of connections per host +# Default: 2 +pool.minConnectionsPerHost=2 diff --git a/src/main/webapp/WEB-INF/web.xml b/src/main/webapp/WEB-INF/web.xml index 0a4126df4..d26286ad4 100644 --- a/src/main/webapp/WEB-INF/web.xml +++ b/src/main/webapp/WEB-INF/web.xml @@ -41,6 +41,11 @@ true + + + listeners.DatabaseLifecycleListener + + /setup servlets.Setup diff --git a/src/test/java/listeners/DatabaseLifecycleListenerTest.java b/src/test/java/listeners/DatabaseLifecycleListenerTest.java new file mode 100644 index 000000000..dde5778d6 --- /dev/null +++ b/src/test/java/listeners/DatabaseLifecycleListenerTest.java @@ -0,0 +1,119 @@ +package listeners; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import dbProcs.ConnectionPool; +import dbProcs.MongoDatabase; +import java.io.IOException; +import java.sql.Connection; +import javax.servlet.ServletContextEvent; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import testUtils.TestProperties; + +/** + * Unit tests for the DatabaseLifecycleListener class. + * + *

Note: Tests that require actual database connectivity will be skipped if the database is not + * available. Run with a proper database setup for full coverage. + */ +public class DatabaseLifecycleListenerTest { + + private static final Logger log = LogManager.getLogger(DatabaseLifecycleListenerTest.class); + private static boolean databaseAvailable = false; + + @Mock private ServletContextEvent mockServletContextEvent; + + private DatabaseLifecycleListener listener; + + @BeforeAll + public static void setupClass() throws IOException { + TestProperties.setTestPropertiesFileDirectory(log); + TestProperties.createMysqlResource(); + TestProperties.createMongoResource(); + + try { + ConnectionPool.initialize(); + Connection conn = ConnectionPool.getConnection(); + conn.close(); + databaseAvailable = true; + log.info("Database is available - running full test suite"); + } catch (Exception e) { + databaseAvailable = false; + log.warn("Database not available - skipping connection-dependent tests: " + e.getMessage()); + } finally { + ConnectionPool.reset(); + } + } + + @BeforeEach + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + ConnectionPool.reset(); + MongoDatabase.resetInstance(); + TestProperties.createMysqlResource(); + TestProperties.createMongoResource(); + listener = new DatabaseLifecycleListener(); + } + + private void requireDatabase() { + assumeTrue(databaseAvailable, "Database not available"); + } + + @Test + @DisplayName("contextInitialized should initialize the connection pool") + public void testContextInitialized() { + requireDatabase(); + + assertFalse( + ConnectionPool.isInitialized(), "Pool should not be initialized before contextInitialized"); + listener.contextInitialized(mockServletContextEvent); + assertTrue( + ConnectionPool.isInitialized(), "Pool should be initialized after contextInitialized"); + } + + @Test + @DisplayName("contextDestroyed should shut down all pools") + public void testContextDestroyed() { + requireDatabase(); + + listener.contextInitialized(mockServletContextEvent); + assertTrue(ConnectionPool.isInitialized(), "Pool should be initialized"); + + listener.contextDestroyed(mockServletContextEvent); + assertFalse( + ConnectionPool.isInitialized(), "Pool should not be initialized after contextDestroyed"); + assertFalse( + MongoDatabase.isInitialized(), "MongoDB should not be initialized after contextDestroyed"); + } + + @Test + @DisplayName("contextInitialized should handle missing config gracefully") + public void testInitializationFailureHandling() throws IOException { + TestProperties.deleteMysqlResource(); + ConnectionPool.reset(); + + listener.contextInitialized(mockServletContextEvent); + + assertFalse( + ConnectionPool.isInitialized(), "Pool should not be initialized when config is missing"); + + TestProperties.createMysqlResource(); + } + + @Test + @DisplayName("contextDestroyed without prior init should not throw") + public void testContextDestroyedWithoutInit() { + assertFalse(ConnectionPool.isInitialized(), "Pool should not be initialized"); + listener.contextDestroyed(mockServletContextEvent); + assertFalse(ConnectionPool.isInitialized(), "Pool should still not be initialized"); + } +} diff --git a/src/test/java/servlets/SetupTest.java b/src/test/java/servlets/SetupTest.java new file mode 100644 index 000000000..5e6f3acc8 --- /dev/null +++ b/src/test/java/servlets/SetupTest.java @@ -0,0 +1,44 @@ +package servlets; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +import org.junit.jupiter.api.Test; + +public class SetupTest { + + @Test + public void validateHostPort_bothEmpty_isValid() { + assertNull(Setup.validateHostPort("", "")); + } + + @Test + public void validateHostPort_bothProvided_isValid() { + assertNull(Setup.validateHostPort("localhost", "3306")); + } + + @Test + public void validateHostPort_onlyHostProvided_isInvalid() { + assertNotNull(Setup.validateHostPort("localhost", "")); + } + + @Test + public void validateHostPort_onlyPortProvided_isInvalid() { + assertNotNull(Setup.validateHostPort("", "3306")); + } + + @Test + public void validateHostPort_bothNull_isValid() { + assertNull(Setup.validateHostPort(null, null)); + } + + @Test + public void validateHostPort_hostNullPortProvided_isInvalid() { + assertNotNull(Setup.validateHostPort(null, "3306")); + } + + @Test + public void validateHostPort_hostProvidedPortNull_isInvalid() { + assertNotNull(Setup.validateHostPort("localhost", null)); + } +} diff --git a/src/test/java/testUtils/TestProperties.java b/src/test/java/testUtils/TestProperties.java index 7942afa41..a03ccf880 100644 --- a/src/test/java/testUtils/TestProperties.java +++ b/src/test/java/testUtils/TestProperties.java @@ -2,9 +2,11 @@ import static org.junit.Assert.fail; +import dbProcs.ConnectionPool; import dbProcs.Constants; import dbProcs.Database; import dbProcs.Getter; +import dbProcs.MongoDatabase; import dbProcs.Setter; import io.github.cdimascio.dotenv.Dotenv; import java.io.BufferedWriter; @@ -488,4 +490,44 @@ public static void createMongoResource() throws IOException { public static void deleteMongoResource() { FileUtils.deleteQuietly(new File(Constants.MONGO_DB_PROP)); } + + /** + * Initialize the MySQL/MariaDB connection pool for tests. Call this in @BeforeClass after + * createMysqlResource(). + */ + public static void initializeConnectionPool() { + ConnectionPool.initialize(); + log.debug("Connection pool initialized for tests"); + } + + /** + * Shutdown the MySQL/MariaDB connection pool after tests. Call this in @AfterClass to clean up + * resources. + */ + public static void shutdownConnectionPool() { + ConnectionPool.shutdown(); + log.debug("Connection pool shut down after tests"); + } + + /** Reset the MySQL/MariaDB connection pool. Useful for tests that need a fresh pool state. */ + public static void resetConnectionPool() { + ConnectionPool.reset(); + log.debug("Connection pool reset"); + } + + /** Reset the MongoDB singleton instance. Call this to ensure a clean state for MongoDB tests. */ + public static void resetMongoSingleton() { + MongoDatabase.resetInstance(); + log.debug("MongoDB singleton reset"); + } + + /** + * Shutdown all database connections (MySQL and MongoDB). Call this in @AfterClass to ensure all + * resources are released. + */ + public static void shutdownAllDatabases() { + shutdownConnectionPool(); + resetMongoSingleton(); + log.debug("All database connections shut down"); + } } diff --git a/src/test/resources/challenges/mongo_challenge_test.properties b/src/test/resources/challenges/mongo_challenge_test.properties new file mode 100644 index 000000000..58604d9d4 --- /dev/null +++ b/src/test/resources/challenges/mongo_challenge_test.properties @@ -0,0 +1,4 @@ +databaseName=test_dbname +databaseUsername=test_user +databasePassword=test_password +databaseCollection=test_collection diff --git a/tests/load/.gitignore b/tests/load/.gitignore new file mode 100644 index 000000000..fbca22537 --- /dev/null +++ b/tests/load/.gitignore @@ -0,0 +1 @@ +results/ diff --git a/tests/load/README.md b/tests/load/README.md new file mode 100644 index 000000000..3f908670d --- /dev/null +++ b/tests/load/README.md @@ -0,0 +1,42 @@ +# Load Tests + +Tests to verify Security Shepherd survives aggressive automated scanning tools without database connection exhaustion. See [issue #536](https://github.com/OWASP/SecurityShepherd/issues/536). + +## load-test.sh + +End-to-end load test that: + +1. Builds and starts the full Docker stack +2. Runs the initial database setup automatically +3. Creates 20 test users +4. Simulates 17 normal users (browsing every 3-8s) and 3 aggressive users (automated scanning, ~10 req/s each) +5. Monitors DB connections and app response times throughout +6. Reports pass/fail based on: + - DB connections stay under 50 (pooling prevents exhaustion) + - No health check failures (app stays responsive) + - Response times stay under 10 seconds + +### Usage + +```bash +# Full run (builds everything from scratch) +./load-test.sh + +# Skip the Maven/Docker build if stack images are already current +./load-test.sh --skip-build +``` + +### Requirements + +- Docker +- curl +- Maven (unless using --skip-build) +- Ports 80, 443, 3306, 27017 available + +### Results + +Each run creates a timestamped directory under `results/` with: + +- `monitor.csv` — time-series of DB connections, HTTP status, and response time +- `cookies_*.txt` — session cookies (auto-cleaned) +- `aggressive_*_requests.txt` — request counts per aggressive user diff --git a/tests/load/load-test.py b/tests/load/load-test.py new file mode 100755 index 000000000..bcbad64fe --- /dev/null +++ b/tests/load/load-test.py @@ -0,0 +1,602 @@ +#!/usr/bin/env python3 +""" +Load test for Security Shepherd connection pooling (issue #536). + +Simulates 20 users: 17 doing normal browsing, 3 running aggressive +automated scanning. Monitors DB connections and app responsiveness +to verify the connection pool prevents exhaustion. + +Usage: + python3 load-test.py [--skip-build] [--duration MINUTES] [--users NORMAL AGGRESSIVE] +""" + +import argparse +import csv +import http.cookiejar +import os +import re +import ssl +import subprocess +import sys +import threading +import time +import urllib.error +import urllib.parse +import urllib.request +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime +from pathlib import Path + +# Disable SSL verification globally (self-signed cert) +SSL_CTX = ssl.create_default_context() +SSL_CTX.check_hostname = False +SSL_CTX.verify_mode = ssl.CERT_NONE + +BASE_URL = "https://localhost" +DB_CONTAINER = "secshep_mariadb" +TOMCAT_CONTAINER = "secshep_tomcat" +DB_PASS = "CowSaysMoo" +ADMIN_USER = "admin" +ADMIN_DEFAULT_PASS = "password" +ADMIN_NEW_PASS = "LoadTestAdmin1" + +SPIDER_PATHS = [ + "/login.jsp", "/register.jsp", "/index.jsp", "/logout", + "/admin/", "/challenges/", "/lessons/", "/setup.jsp", + "/css/theCss.css", "/js/jquery.js", "/login", "/register", + "/passwordChange", "/usernameChange", "/refreshMenu", + "/mobileLogin", "/setup", "/getModule", "/getCheat", + "/feedbackSubmit", "/solutionSubmit", +] + +NORMAL_PAGES = ["/login.jsp", "/index.jsp", "/register.jsp", "/logout"] + + +# ── Utilities ────────────────────────────────────────────────────── + + +def log(msg): + print(f"\033[0;32m[+]\033[0m {msg}", flush=True) + + +def warn(msg): + print(f"\033[1;33m[!]\033[0m {msg}", flush=True) + + +def fail(msg): + print(f"\033[0;31m[-]\033[0m {msg}", flush=True) + sys.exit(1) + + +def docker_exec(container, cmd): + """Run a command in a Docker container and return stdout.""" + result = subprocess.run( + ["docker", "exec", container] + cmd, + capture_output=True, text=True, timeout=30 + ) + return result.stdout.strip(), result.returncode + + +def docker_compose(*args): + """Run docker compose command.""" + result = subprocess.run( + ["docker", "compose"] + list(args), + capture_output=True, text=True, timeout=300 + ) + return result.stdout, result.stderr, result.returncode + + +def db_query(sql): + """Execute a SQL query against MariaDB and return stdout.""" + stdout, rc = docker_exec(DB_CONTAINER, [ + "mariadb", f"-uroot", f"-p{DB_PASS}", "-sN", "-e", sql + ]) + return stdout if rc == 0 else None + + +def get_connections(): + """Get current DB thread count.""" + result = db_query("SHOW STATUS LIKE 'Threads_connected';") + if result: + parts = result.split() + if len(parts) >= 2: + return int(parts[1]) + return None + + +class ShepherdSession: + """HTTP session with cookie handling for Security Shepherd.""" + + def __init__(self): + self.cookie_jar = http.cookiejar.CookieJar() + self.opener = urllib.request.build_opener( + urllib.request.HTTPCookieProcessor(self.cookie_jar), + urllib.request.HTTPSHandler(context=SSL_CTX), + ) + + def get(self, path, follow_redirects=True): + """GET request, returns (status_code, body, headers).""" + url = BASE_URL + path + try: + req = urllib.request.Request(url) + resp = self.opener.open(req, timeout=10) + return resp.status, resp.read().decode("utf-8", errors="replace"), dict(resp.headers) + except urllib.error.HTTPError as e: + return e.code, e.read().decode("utf-8", errors="replace"), dict(e.headers) + except Exception: + return 0, "", {} + + def post(self, path, data, follow_redirects=False): + """POST request, returns (status_code, body, location_header).""" + url = BASE_URL + path + encoded = urllib.parse.urlencode(data).encode("utf-8") + req = urllib.request.Request(url, data=encoded, method="POST") + req.add_header("Content-Type", "application/x-www-form-urlencoded") + try: + resp = self.opener.open(req, timeout=10) + return resp.status, resp.read().decode("utf-8", errors="replace"), resp.url + except urllib.error.HTTPError as e: + location = e.headers.get("Location", "") + return e.code, e.read().decode("utf-8", errors="replace"), location + except Exception: + return 0, "", "" + + @property + def token(self): + """Get the CSRF token cookie value.""" + for cookie in self.cookie_jar: + if cookie.name == "token": + return cookie.value + return None + + def get_csrf_from_page(self, path): + """Extract CSRF token from page HTML.""" + _, body, _ = self.get(path) + match = re.search(r'csrfToken:\s*"([^"]+)"', body) + return match.group(1) if match else None + + +# ── Setup Steps ──────────────────────────────────────────────────── + + +def build_and_start(skip_build, project_root): + """Build and start the Docker stack.""" + os.chdir(project_root) + + if not skip_build: + log("Building WAR and Docker artifacts...") + result = subprocess.run( + ["mvn", "-Pdocker", "clean", "install", "-DskipTests", "-B", "-q"], + timeout=300 + ) + if result.returncode != 0: + fail("Maven build failed") + subprocess.run(["docker", "compose", "build", "--no-cache", "-q"], timeout=300) + else: + log("Skipping build (--skip-build)") + + # Check for existing volumes + stdout, _, _ = docker_compose("down", "-v") + # Remove any orphaned volumes + result = subprocess.run( + ["docker", "volume", "ls", "-q", "--filter", "name=securityshepherd"], + capture_output=True, text=True + ) + for vol in result.stdout.strip().split("\n"): + if vol: + subprocess.run(["docker", "volume", "rm", vol], capture_output=True) + + log("Starting stack...") + stdout, stderr, rc = docker_compose("up", "-d") + + +def wait_for_services(): + """Wait for MariaDB and Tomcat to be ready.""" + log("Waiting for MariaDB...") + for i in range(60): + _, rc = docker_exec(DB_CONTAINER, ["mariadb", f"-uroot", f"-p{DB_PASS}", "-e", "SELECT 1"]) + if rc == 0: + break + time.sleep(2) + else: + fail("MariaDB did not start in time") + log("MariaDB ready") + + log("Waiting for Tomcat...") + for i in range(60): + try: + req = urllib.request.Request(BASE_URL) + urllib.request.urlopen(req, timeout=3, context=SSL_CTX) + break + except Exception: + time.sleep(2) + else: + fail("Tomcat did not start in time") + log("Tomcat ready") + + +def configure_platform(): + """Login as admin, change password, and enable registration.""" + session = ShepherdSession() + + # Get initial session + log("Logging in as admin (admin/password)...") + session.get("/login.jsp") + + # Login + status, body, location = session.post("/login", { + "login": ADMIN_USER, + "pwd": ADMIN_DEFAULT_PASS, + }) + + # Change temporary password + log("Changing admin password...") + token = session.token + if not token: + fail("No CSRF token cookie after admin login") + + session.post("/passwordChange", { + "currentPassword": ADMIN_DEFAULT_PASS, + "newPassword": ADMIN_NEW_PASS, + "passwordConfirmation": ADMIN_NEW_PASS, + "csrfToken": token, + }) + log("Admin password changed") + + # Enable registration + log("Enabling registration...") + token = session.token + status, body, _ = session.post("/updateRegistration", {"csrfToken": token}) + + if "Opened" in body: + log("Registration enabled") + elif "Closed" in body: + # Was already open, got toggled closed — toggle again + session.post("/updateRegistration", {"csrfToken": token}) + log("Registration enabled (toggled twice)") + else: + warn(f"Toggle response: {body[:200]}") + fail("Could not enable registration") + + return session + + +def register_users(num_users): + """Register test users via the web UI.""" + log(f"Registering {num_users} test users...") + created = 0 + + for i in range(1, num_users + 1): + username = f"loadtest_user_{i}" + session = ShepherdSession() + + csrf = session.get_csrf_from_page("/register.jsp") + if not csrf: + warn(f"Could not get CSRF token for {username}") + continue + + status, body, location = session.post("/register", { + "userName": username, + "passWord": username, + "passWordConfirm": username, + "userAddress": f"{username}@test.com", + "userAddressCnf": f"{username}@test.com", + "csrfToken": csrf, + }) + + if status == 302 or "login.jsp" in str(location): + created += 1 + else: + warn(f"Failed to register {username} (HTTP {status})") + + if created == 0: + fail("No users were registered") + + log(f"{created} / {num_users} users registered") + return created + + +def login_users(num_users): + """Login all test users and return their sessions.""" + log("Logging in test users...") + sessions = {} + logged_in = 0 + + for i in range(1, num_users + 1): + username = f"loadtest_user_{i}" + session = ShepherdSession() + session.get("/login.jsp") + + status, body, location = session.post("/login", { + "login": username, + "pwd": username, + }) + + if "index.jsp" in str(location): + sessions[i] = session + logged_in += 1 + else: + warn(f"Login failed for {username} (HTTP {status}, location: {location})") + + if logged_in == 0: + fail("No users could log in") + + log(f"{logged_in} / {num_users} users logged in") + return sessions + + +# ── Traffic Simulation ───────────────────────────────────────────── + + +def normal_user_traffic(session, duration): + """Simulate a normal user browsing every 3-8 seconds.""" + import random + end_time = time.time() + duration + requests_made = 0 + + while time.time() < end_time: + page = random.choice(NORMAL_PAGES) + try: + session.get(page) + requests_made += 1 + except Exception: + pass + time.sleep(random.uniform(3, 8)) + + return requests_made + + +def aggressive_user_traffic(session, duration): + """Simulate aggressive automated scanning (~10 req/s).""" + import random + end_time = time.time() + duration + requests_made = 0 + + while time.time() < end_time: + path = random.choice(SPIDER_PATHS) + + # GET (spider) + try: + session.get(path) + requests_made += 1 + except Exception: + pass + + # POST with random params (fuzzer) + try: + session.post(path, { + "param1": "test", + "param2": os.urandom(8).hex(), + }) + requests_made += 1 + except Exception: + pass + + time.sleep(random.uniform(0, 0.2)) + + return requests_made + + +# ── Monitoring ───────────────────────────────────────────────────── + + +def monitor_loop(duration, interval, results_file, stop_event): + """Monitor DB connections and app responsiveness.""" + with open(results_file, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["timestamp", "db_connections", "http_status", "response_time_ms"]) + + while not stop_event.is_set(): + ts = datetime.now().strftime("%H:%M:%S") + conns = get_connections() + + # Health check + start = time.time() + try: + req = urllib.request.Request(BASE_URL + "/login.jsp") + resp = urllib.request.urlopen(req, timeout=10, context=SSL_CTX) + http_status = resp.status + except urllib.error.HTTPError as e: + http_status = e.code + except Exception: + http_status = 0 + response_ms = int((time.time() - start) * 1000) + + writer.writerow([ts, conns or "N/A", http_status, response_ms]) + f.flush() + + conns_str = str(conns) if conns else "N/A" + print( + f" {ts} | Connections: {conns_str:<4} | HTTP: {http_status} | Response: {response_ms}ms", + flush=True, + ) + + stop_event.wait(interval) + + +# ── Report ───────────────────────────────────────────────────────── + + +def generate_report(results_file, config, aggressive_total): + """Parse monitor CSV and print results.""" + connections = [] + responses = [] + failed = 0 + total = 0 + + with open(results_file) as f: + reader = csv.DictReader(f) + for row in reader: + total += 1 + try: + conns = int(row["db_connections"]) + connections.append(conns) + except (ValueError, KeyError): + pass + try: + resp = int(row["response_time_ms"]) + responses.append(resp) + except (ValueError, KeyError): + pass + try: + status = int(row["http_status"]) + if status == 0 or status >= 500: + failed += 1 + except (ValueError, KeyError): + failed += 1 + + if not connections or not responses: + fail("No monitoring data collected") + + max_conns = max(connections) + min_conns = min(connections) + avg_conns = sum(connections) // len(connections) + max_resp = max(responses) + avg_resp = sum(responses) // len(responses) + + print() + print("=" * 59) + print(" LOAD TEST RESULTS") + print("=" * 59) + print() + print(" Configuration:") + print(f" Normal users: {config['normal']} (request every 3-8s)") + print(f" Aggressive users: {config['aggressive']} (automated scanning, ~10 req/s each)") + print(f" Duration: {config['duration'] // 60} minutes") + print() + print(" Database Connections:") + print(f" Baseline: {config['baseline']}") + print(f" Min: {min_conns}") + print(f" Max: {max_conns}") + print(f" Average: {avg_conns}") + print() + print(" App Responsiveness:") + print(f" Avg response: {avg_resp}ms") + print(f" Max response: {max_resp}ms") + print(f" Failed checks: {failed} / {total}") + print() + print(" Aggressive Traffic:") + print(f" Total requests: {aggressive_total}") + print() + + passed = True + if max_conns > 50: + print(f" \033[0;31mFAIL: Max connections ({max_conns}) exceeded 50\033[0m") + passed = False + if failed > 0: + print(f" \033[0;31mFAIL: {failed} health checks failed\033[0m") + passed = False + if max_resp > 10000: + print(f" \033[0;31mFAIL: Max response time ({max_resp}ms) exceeded 10s\033[0m") + passed = False + if passed: + print(" \033[0;32mPASS: Connection pool held under load\033[0m") + + print() + print(f" Full results: {results_file}") + print("=" * 59) + + return passed + + +# ── Main ─────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser(description="Security Shepherd load test") + parser.add_argument("--skip-build", action="store_true", help="Skip Maven/Docker build") + parser.add_argument("--duration", type=int, default=5, help="Test duration in minutes (default: 5)") + parser.add_argument("--normal-users", type=int, default=17, help="Number of normal users (default: 17)") + parser.add_argument("--aggressive-users", type=int, default=3, help="Number of aggressive users (default: 3)") + parser.add_argument("--monitor-interval", type=int, default=10, help="Monitor interval in seconds (default: 10)") + args = parser.parse_args() + + duration = args.duration * 60 + total_users = args.normal_users + args.aggressive_users + + script_dir = Path(__file__).resolve().parent + project_root = script_dir.parent.parent + results_dir = script_dir / "results" / datetime.now().strftime("%Y%m%d-%H%M%S") + results_dir.mkdir(parents=True, exist_ok=True) + monitor_file = str(results_dir / "monitor.csv") + + # Step 1: Build and start + build_and_start(args.skip_build, str(project_root)) + + # Step 2: Wait for services + wait_for_services() + + # Step 3: Configure platform (admin login, password change, enable registration) + configure_platform() + + # Step 4: Register users + register_users(total_users) + + # Step 5: Login users + sessions = login_users(total_users) + + # Step 6: Record baseline + baseline = get_connections() or 0 + log(f"Baseline DB connections: {baseline}") + + # Step 7: Start monitoring + log("Starting monitor...") + stop_monitor = threading.Event() + monitor_thread = threading.Thread( + target=monitor_loop, + args=(duration, args.monitor_interval, monitor_file, stop_monitor), + daemon=True, + ) + monitor_thread.start() + + # Step 8: Start traffic + log(f"Starting {args.normal_users} normal + {args.aggressive_users} aggressive users for {args.duration} minutes...") + print() + + with ThreadPoolExecutor(max_workers=total_users) as executor: + futures = {} + + # Normal users + for i in range(1, args.normal_users + 1): + if i in sessions: + f = executor.submit(normal_user_traffic, sessions[i], duration) + futures[f] = ("normal", i) + + # Aggressive users + for i in range(args.normal_users + 1, total_users + 1): + if i in sessions: + f = executor.submit(aggressive_user_traffic, sessions[i], duration) + futures[f] = ("aggressive", i) + + # Wait for all to complete + aggressive_total = 0 + for future in as_completed(futures): + kind, user_id = futures[future] + try: + count = future.result() + if kind == "aggressive": + aggressive_total += count + except Exception as e: + warn(f"User {user_id} ({kind}) error: {e}") + + # Step 9: Stop monitoring and report + time.sleep(5) + stop_monitor.set() + monitor_thread.join(timeout=10) + + print() + log("Load test complete. Analyzing results...") + + config = { + "normal": args.normal_users, + "aggressive": args.aggressive_users, + "duration": duration, + "baseline": baseline, + } + + passed = generate_report(monitor_file, config, aggressive_total) + sys.exit(0 if passed else 1) + + +if __name__ == "__main__": + main()