From 847b91bf1fdf466329ef3b5755d3c6df4f9bc946 Mon Sep 17 00:00:00 2001 From: Aaron Schulz Date: Fri, 16 Sep 2016 21:39:57 -0700 Subject: [PATCH] Make database classes handle hyphens in $wgDBname * Add DatabaseDomain class to handle passing domains around. It also can be cast to and from strings, which are of the same format as wfWikiId() except with hyphens escaped. * Make IDatabase::getDomainID() use these IDs so they can be passed into LoadBalancer::getConnection() and friends without breaking on sites with a hyphen in the DB name. * Add more LBFactory unit tests for domains. Bug: T145840 Change-Id: Icfed62b251af8cef706a899197c3ccdb730ef4d1 --- autoload.php | 1 + includes/db/loadbalancer/LBFactoryMW.php | 4 +- includes/libs/rdbms/database/DBConnRef.php | 13 +- includes/libs/rdbms/database/Database.php | 15 +- .../libs/rdbms/database/DatabaseDomain.php | 203 ++++++++++++++++++ includes/libs/rdbms/lbfactory/LBFactory.php | 13 +- .../libs/rdbms/loadbalancer/LoadBalancer.php | 33 ++- tests/phpunit/includes/db/LBFactoryTest.php | 142 ++++++++++++ .../rdbms/database/DatabaseDomainTest.php | 69 ++++++ 9 files changed, 469 insertions(+), 24 deletions(-) create mode 100644 includes/libs/rdbms/database/DatabaseDomain.php create mode 100644 tests/phpunit/includes/libs/rdbms/database/DatabaseDomainTest.php diff --git a/autoload.php b/autoload.php index 716e56db87..035c15265f 100644 --- a/autoload.php +++ b/autoload.php @@ -318,6 +318,7 @@ $wgAutoloadLocalClasses = [ 'DataUpdate' => __DIR__ . '/includes/deferred/DataUpdate.php', 'Database' => __DIR__ . '/includes/libs/rdbms/database/Database.php', 'DatabaseBase' => __DIR__ . '/includes/libs/rdbms/database/DatabaseBase.php', + 'DatabaseDomain' => __DIR__ . '/includes/libs/rdbms/database/DatabaseDomain.php', 'DatabaseInstaller' => __DIR__ . '/includes/installer/DatabaseInstaller.php', 'DatabaseLag' => __DIR__ . '/maintenance/lag.php', 'DatabaseLogEntry' => __DIR__ . '/includes/logging/LogEntry.php', diff --git a/includes/db/loadbalancer/LBFactoryMW.php b/includes/db/loadbalancer/LBFactoryMW.php index 49d0624710..16faeb7390 100644 --- a/includes/db/loadbalancer/LBFactoryMW.php +++ b/includes/db/loadbalancer/LBFactoryMW.php @@ -37,10 +37,10 @@ abstract class LBFactoryMW extends LBFactory implements DestructibleService { * @TODO: inject objects via dependency framework */ public function __construct( array $conf ) { - global $wgCommandLineMode, $wgSQLMode, $wgDBmysql5; + global $wgCommandLineMode, $wgSQLMode, $wgDBmysql5, $wgDBname, $wgDBprefix; $defaults = [ - 'localDomain' => wfWikiID(), + 'localDomain' => new DatabaseDomain( $wgDBname, null, $wgDBprefix ), 'hostname' => wfHostname(), 'trxProfiler' => Profiler::instance()->getTransactionProfiler(), 'replLogger' => LoggerFactory::getInstance( 'DBReplication' ), diff --git a/includes/libs/rdbms/database/DBConnRef.php b/includes/libs/rdbms/database/DBConnRef.php index 876ee30e30..0d9b692816 100644 --- a/includes/libs/rdbms/database/DBConnRef.php +++ b/includes/libs/rdbms/database/DBConnRef.php @@ -14,22 +14,22 @@ class DBConnRef implements IDatabase { /** @var IDatabase|null Live connection handle */ private $conn; - /** @var array|null */ + /** @var array|null N-tuple of (server index, group, DatabaseDomain|string) */ private $params; const FLD_INDEX = 0; const FLD_GROUP = 1; - const FLD_WIKI = 2; + const FLD_DOMAIN = 2; /** * @param ILoadBalancer $lb - * @param IDatabase|array $conn Connection or (server index, group, wiki ID) + * @param IDatabase|array $conn Connection or (server index, group, DatabaseDomain|string) */ public function __construct( ILoadBalancer $lb, $conn ) { $this->lb = $lb; if ( $conn instanceof IDatabase ) { $this->conn = $conn; // live handle - } elseif ( count( $conn ) >= 3 && $conn[self::FLD_WIKI] !== false ) { + } elseif ( count( $conn ) >= 3 && $conn[self::FLD_DOMAIN] !== false ) { $this->params = $conn; } else { throw new InvalidArgumentException( "Missing lazy connection arguments." ); @@ -147,8 +147,9 @@ class DBConnRef implements IDatabase { public function getDomainID() { if ( $this->conn === null ) { - // Avoid triggering a connection - return $this->params[self::FLD_WIKI]; + $domain = $this->params[self::FLD_DOMAIN]; + // Avoid triggering a database connection + return $domain instanceof DatabaseDomain ? $domain->getId() : $domain; } return $this->__call( __FUNCTION__, func_get_args() ); diff --git a/includes/libs/rdbms/database/Database.php b/includes/libs/rdbms/database/Database.php index 9a63b7fced..4cab119234 100644 --- a/includes/libs/rdbms/database/Database.php +++ b/includes/libs/rdbms/database/Database.php @@ -112,6 +112,8 @@ abstract class Database implements IDatabase, LoggerAwareInterface { protected $htmlErrors; /** @var string */ protected $delimiter = ';'; + /** @var DatabaseDomain */ + protected $currentDomain; /** * Either 1 if a transaction is active or 0 otherwise. @@ -288,6 +290,10 @@ abstract class Database implements IDatabase, LoggerAwareInterface { if ( $user ) { $this->open( $server, $user, $password, $dbName ); } + + $this->currentDomain = ( $this->mDBname != '' ) + ? new DatabaseDomain( $this->mDBname, null, $this->mTablePrefix ) + : DatabaseDomain::newUnspecified(); } /** @@ -442,6 +448,9 @@ abstract class Database implements IDatabase, LoggerAwareInterface { $old = $this->mTablePrefix; if ( $prefix !== null ) { $this->mTablePrefix = $prefix; + $this->currentDomain = ( $this->mDBname != '' ) + ? new DatabaseDomain( $this->mDBname, null, $this->mTablePrefix ) + : DatabaseDomain::newUnspecified(); } return $old; @@ -621,11 +630,7 @@ abstract class Database implements IDatabase, LoggerAwareInterface { } public function getDomainID() { - if ( $this->mTablePrefix != '' ) { - return "{$this->mDBname}-{$this->mTablePrefix}"; - } else { - return $this->mDBname; - } + return $this->currentDomain->getId(); } final public function getWikiID() { diff --git a/includes/libs/rdbms/database/DatabaseDomain.php b/includes/libs/rdbms/database/DatabaseDomain.php new file mode 100644 index 0000000000..01b6b21f04 --- /dev/null +++ b/includes/libs/rdbms/database/DatabaseDomain.php @@ -0,0 +1,203 @@ +database = $database; + if ( $schema !== null && ( !is_string( $schema ) || !strlen( $schema ) ) ) { + throw new InvalidArgumentException( "Schema must be null or a non-empty string." ); + } + $this->schema = $schema; + if ( !is_string( $prefix ) ) { + throw new InvalidArgumentException( "Prefix must be a string." ); + } + $this->prefix = $prefix; + $this->equivalentString = $this->convertToString(); + } + + /** + * @param DatabaseDomain|string $domain Result of DatabaseDomain::toString() + * @return DatabaseDomain + */ + public static function newFromId( $domain ) { + if ( $domain instanceof self ) { + return $domain; + } + + $parts = array_map( [ __CLASS__, 'decode' ], explode( '-', $domain ) ); + + $schema = null; + $prefix = ''; + + if ( count( $parts ) == 1 ) { + $database = $parts[0]; + } elseif ( count( $parts ) == 2 ) { + list( $database, $prefix ) = $parts; + } elseif ( count( $parts ) == 3 ) { + list( $database, $schema, $prefix ) = $parts; + } else { + throw new InvalidArgumentException( "Domain has too few or too many parts." ); + } + + if ( $database === '' ) { + $database = null; + } + + return new self( $database, $schema, $prefix ); + } + + /** + * @return DatabaseDomain + */ + public static function newUnspecified() { + return new self( null, null, '' ); + } + + /** + * @param DatabaseDomain|string $other + * @return bool + */ + public function equals( $other ) { + if ( $other instanceof DatabaseDomain ) { + return ( + $this->database === $other->database && + $this->schema === $other->schema && + $this->prefix === $other->prefix + ); + } + + return ( $this->equivalentString === $other ); + } + + /** + * @return string|null Database name + */ + public function getDatabase() { + return $this->database; + } + + /** + * @return string|null Database schema + */ + public function getSchema() { + return $this->schema; + } + + /** + * @return string Table prefix + */ + public function getTablePrefix() { + return $this->prefix; + } + + /** + * @return string + */ + public function getId() { + return $this->equivalentString; + } + + /** + * @return string + */ + private function convertToString() { + $parts = [ $this->database ]; + if ( $this->schema !== null ) { + $parts[] = $this->schema; + } + if ( $this->prefix != '' ) { + $parts[] = $this->prefix; + } + + return implode( '-', array_map( [ __CLASS__, 'encode' ], $parts ) ); + } + + private static function encode( $decoded ) { + $encoded = ''; + + $length = strlen( $decoded ); + for ( $i = 0; $i < $length; ++$i ) { + $char = $decoded[$i]; + if ( $char === '-' ) { + $encoded .= '?h'; + } elseif ( $char === '?' ) { + $encoded .= '??'; + } else { + $encoded .= $char; + } + } + + return $encoded; + } + + private static function decode( $encoded ) { + $decoded = ''; + + $length = strlen( $encoded ); + for ( $i = 0; $i < $length; ++$i ) { + $char = $encoded[$i]; + if ( $char === '?' ) { + $nextChar = isset( $encoded[$i + 1] ) ? $encoded[$i + 1] : null; + if ( $nextChar === 'h' ) { + $decoded .= '-'; + ++$i; + } elseif ( $nextChar === '?' ) { + $decoded .= '?'; + ++$i; + } else { + $decoded .= $char; + } + } else { + $decoded .= $char; + } + } + + return $decoded; + } + + /** + * @return string + */ + function __toString() { + return $this->getId(); + } +} diff --git a/includes/libs/rdbms/lbfactory/LBFactory.php b/includes/libs/rdbms/lbfactory/LBFactory.php index 49fac6afc2..3ab7362a61 100644 --- a/includes/libs/rdbms/lbfactory/LBFactory.php +++ b/includes/libs/rdbms/lbfactory/LBFactory.php @@ -49,7 +49,7 @@ abstract class LBFactory { /** @var WANObjectCache */ protected $wanCache; - /** @var string Local domain */ + /** @var DatabaseDomain Local domain */ protected $localDomain; /** @var string Local hostname of the app server */ protected $hostname; @@ -79,7 +79,9 @@ abstract class LBFactory { * @param array $conf */ public function __construct( array $conf ) { - $this->localDomain = isset( $conf['localDomain'] ) ? $conf['localDomain'] : ''; + $this->localDomain = isset( $conf['localDomain'] ) + ? DatabaseDomain::newFromId( $conf['localDomain'] ) + : DatabaseDomain::newUnspecified(); if ( isset( $conf['readOnlyReason'] ) && is_string( $conf['readOnlyReason'] ) ) { $this->readOnlyReason = $conf['readOnlyReason']; @@ -638,8 +640,11 @@ abstract class LBFactory { * @since 1.28 */ public function setDomainPrefix( $prefix ) { - list( $dbName, ) = explode( '-', $this->localDomain, 2 ); - $this->localDomain = "{$dbName}-{$prefix}"; + $this->localDomain = new DatabaseDomain( + $this->localDomain->getDatabase(), + null, + $prefix + ); $this->forEachLB( function( LoadBalancer $lb ) use ( $prefix ) { $lb->setDomainPrefix( $prefix ); diff --git a/includes/libs/rdbms/loadbalancer/LoadBalancer.php b/includes/libs/rdbms/loadbalancer/LoadBalancer.php index 57c905facc..3c5d9b19a3 100644 --- a/includes/libs/rdbms/loadbalancer/LoadBalancer.php +++ b/includes/libs/rdbms/loadbalancer/LoadBalancer.php @@ -84,8 +84,10 @@ class LoadBalancer implements ILoadBalancer { private $trxRoundId = false; /** @var array[] Map of (name => callable) */ private $trxRecurringCallbacks = []; - /** @var string Local Domain ID and default for selectDB() calls */ + /** @var DatabaseDomain Local Domain ID and default for selectDB() calls */ private $localDomain; + /** @var string Alternate ID string for the domain instead of DatabaseDomain::getId() */ + private $localDomainIdAlias; /** @var string Current server name */ private $host; /** @var bool Whether this PHP instance is for a CLI script */ @@ -113,10 +115,22 @@ class LoadBalancer implements ILoadBalancer { throw new InvalidArgumentException( __CLASS__ . ': missing servers parameter' ); } $this->mServers = $params['servers']; + + $this->localDomain = isset( $params['localDomain'] ) + ? DatabaseDomain::newFromId( $params['localDomain'] ) + : DatabaseDomain::newUnspecified(); + // In case a caller assumes that the domain ID is simply -, which is almost + // always true, gracefully handle the case when they fail to account for escaping. + if ( $this->localDomain->getTablePrefix() != '' ) { + $this->localDomainIdAlias = + $this->localDomain->getDatabase() . '-' . $this->localDomain->getTablePrefix(); + } else { + $this->localDomainIdAlias = $this->localDomain->getDatabase(); + } + $this->mWaitTimeout = isset( $params['waitTimeout'] ) ? $params['waitTimeout'] : self::POS_WAIT_TIMEOUT; - $this->localDomain = isset( $params['localDomain'] ) ? $params['localDomain'] : ''; $this->mReadIndex = -1; $this->mConns = [ @@ -514,7 +528,7 @@ class LoadBalancer implements ILoadBalancer { ' with invalid server index' ); } - if ( $domain === $this->localDomain ) { + if ( $this->localDomain->equals( $domain ) || $domain === $this->localDomainIdAlias ) { $domain = false; // local connection requested } @@ -652,7 +666,7 @@ class LoadBalancer implements ILoadBalancer { } public function openConnection( $i, $domain = false ) { - if ( $domain === $this->localDomain ) { + if ( $this->localDomain->equals( $domain ) || $domain === $this->localDomainIdAlias ) { $domain = false; // local connection requested } @@ -708,7 +722,9 @@ class LoadBalancer implements ILoadBalancer { * @return IDatabase */ private function openForeignConnection( $i, $domain ) { - list( $dbName, $prefix ) = explode( '-', $domain, 2 ) + [ '', '' ]; + $domainInstance = DatabaseDomain::newFromId( $domain ); + $dbName = $domainInstance->getDatabase(); + $prefix = $domainInstance->getTablePrefix(); if ( isset( $this->mConns['foreignUsed'][$i][$domain] ) ) { // Reuse an already-used connection @@ -1612,8 +1628,11 @@ class LoadBalancer implements ILoadBalancer { * @since 1.28 */ public function setDomainPrefix( $prefix ) { - list( $dbName, ) = explode( '-', $this->localDomain, 2 ); - $this->localDomain = "{$dbName}-{$prefix}"; + $this->localDomain = new DatabaseDomain( + $this->localDomain->getDatabase(), + null, + $prefix + ); $this->forEachOpenConnection( function ( IDatabase $db ) use ( $prefix ) { $db->tablePrefix( $prefix ); diff --git a/tests/phpunit/includes/db/LBFactoryTest.php b/tests/phpunit/includes/db/LBFactoryTest.php index 5affa9cd92..cac2d6dd97 100644 --- a/tests/phpunit/includes/db/LBFactoryTest.php +++ b/tests/phpunit/includes/db/LBFactoryTest.php @@ -216,4 +216,146 @@ class LBFactoryTest extends MediaWikiTestCase { $cp->shutdownLB( $lb ); $cp->shutdown(); } + + private function newLBFactoryMulti( array $baseOverride = [], array $serverOverride = [] ) { + global $wgDBserver, $wgDBuser, $wgDBpassword, $wgDBname, $wgDBtype; + + return new LBFactoryMulti( $baseOverride + [ + 'sectionsByDB' => [], + 'sectionLoads' => [ + 'DEFAULT' => [ + 'test-db1' => 1, + ], + ], + 'serverTemplate' => $serverOverride + [ + 'dbname' => $wgDBname, + 'user' => $wgDBuser, + 'password' => $wgDBpassword, + 'type' => $wgDBtype, + 'flags' => DBO_DEFAULT + ], + 'hostsByName' => [ + 'test-db1' => $wgDBserver, + ], + 'loadMonitorClass' => 'LoadMonitorNull', + 'localDomain' => wfWikiID() + ] ); + } + + public function testNiceDomains() { + global $wgDBname; + + $factory = $this->newLBFactoryMulti(); + $lb = $factory->getMainLB(); + + $db = $lb->getConnectionRef( DB_MASTER ); + $this->assertEquals( + $wgDBname, + $db->getDomainID() + ); + unset( $db ); + + /** @var DatabaseBase $db */ + $db = $lb->getConnection( DB_MASTER, [], '' ); + + $this->assertEquals( + '', + $db->getDomainID() + ); + + $this->assertEquals( + $db->addIdentifierQuotes( 'page' ), + $db->tableName( 'page' ), + "Correct full table name" + ); + + $this->assertEquals( + $db->addIdentifierQuotes( $wgDBname ) . '.' . $db->addIdentifierQuotes( 'page' ), + $db->tableName( "$wgDBname.page" ), + "Correct full table name" + ); + + $this->assertEquals( + $db->addIdentifierQuotes( 'nice_db' ) . '.' . $db->addIdentifierQuotes( 'page' ), + $db->tableName( 'nice_db.page' ), + "Correct full table name" + ); + + $factory->setDomainPrefix( 'my_' ); + $this->assertEquals( + '', + $db->getDomainID() + ); + $this->assertEquals( + $db->addIdentifierQuotes( 'my_page' ), + $db->tableName( 'page' ), + "Correct full table name" + ); + $this->assertEquals( + $db->addIdentifierQuotes( 'other_nice_db' ) . '.' . $db->addIdentifierQuotes( 'page' ), + $db->tableName( 'other_nice_db.page' ), + "Correct full table name" + ); + + $factory->closeAll(); + $factory->destroy(); + } + + public function testTrickyDomain() { + $dbname = 'unittest-domain'; + $factory = $this->newLBFactoryMulti( + [ 'localDomain' => $dbname ], [ 'dbname' => $dbname ] ); + $lb = $factory->getMainLB(); + /** @var DatabaseBase $db */ + $db = $lb->getConnection( DB_MASTER, [], '' ); + + $this->assertEquals( + '', + $db->getDomainID() + ); + + $this->assertEquals( + $db->addIdentifierQuotes( 'page' ), + $db->tableName( 'page' ), + "Correct full table name" + ); + + $this->assertEquals( + $db->addIdentifierQuotes( $dbname ) . '.' . $db->addIdentifierQuotes( 'page' ), + $db->tableName( "$dbname.page" ), + "Correct full table name" + ); + + $this->assertEquals( + $db->addIdentifierQuotes( 'nice_db' ) . '.' . $db->addIdentifierQuotes( 'page' ), + $db->tableName( 'nice_db.page' ), + "Correct full table name" + ); + + $factory->setDomainPrefix( 'my_' ); + + $this->assertEquals( + $db->addIdentifierQuotes( 'my_page' ), + $db->tableName( 'page' ), + "Correct full table name" + ); + $this->assertEquals( + $db->addIdentifierQuotes( 'other_nice_db' ) . '.' . $db->addIdentifierQuotes( 'page' ), + $db->tableName( 'other_nice_db.page' ), + "Correct full table name" + ); + + \MediaWiki\suppressWarnings(); + $this->assertFalse( $db->selectDB( 'garbage-db' ) ); + \MediaWiki\restoreWarnings(); + + $this->assertEquals( + $db->addIdentifierQuotes( 'garbage-db' ) . '.' . $db->addIdentifierQuotes( 'page' ), + $db->tableName( 'garbage-db.page' ), + "Correct full table name" + ); + + $factory->closeAll(); + $factory->destroy(); + } } diff --git a/tests/phpunit/includes/libs/rdbms/database/DatabaseDomainTest.php b/tests/phpunit/includes/libs/rdbms/database/DatabaseDomainTest.php new file mode 100644 index 0000000000..d13fbf9341 --- /dev/null +++ b/tests/phpunit/includes/libs/rdbms/database/DatabaseDomainTest.php @@ -0,0 +1,69 @@ +setExpectedException( InvalidArgumentException::class ); + } + + $domain = new DatabaseDomain( $db, $schema, $prefix ); + $this->assertInstanceOf( DatabaseDomain::class, $domain ); + $this->assertEquals( $db, $domain->getDatabase() ); + $this->assertEquals( $schema, $domain->getSchema() ); + $this->assertEquals( $prefix, $domain->getTablePrefix() ); + $this->assertEquals( $id, $domain->getId() ); + } + + public static function provideNewFromId() { + return [ + // basic + [ 'foo', 'foo', null, '' ], + // - + [ 'foo-bar', 'foo', null, 'bar' ], + [ 'foo-bar-baz', 'foo', 'bar', 'baz' ], + // ?h -> - + [ 'foo?hbar-baz-baa', 'foo-bar', 'baz', 'baa' ], + // ?? -> ? + [ 'foo??bar-baz-baa', 'foo?bar', 'baz', 'baa' ], + // ? is left alone + [ 'foo?bar-baz-baa', 'foo?bar', 'baz', 'baa' ], + // too many parts + [ 'foo-bar-baz-baa', '', '', '', true ], + ]; + } + + /** + * @dataProvider provideNewFromId + */ + public function testNewFromId( $id, $db, $schema, $prefix, $exception = false ) { + if ( $exception ) { + $this->setExpectedException( InvalidArgumentException::class ); + } + $domain = DatabaseDomain::newFromId( $id ); + $this->assertInstanceOf( DatabaseDomain::class, $domain ); + $this->assertEquals( $db, $domain->getDatabase() ); + $this->assertEquals( $schema, $domain->getSchema() ); + $this->assertEquals( $prefix, $domain->getTablePrefix() ); + } +} -- 2.20.1