fix MSSQL bug on org (#3405)

This commit is contained in:
Lunny Xiao 2018-01-27 09:20:59 -06:00 committed by Lauris BH
parent a0c397df08
commit 97fe773491
28 changed files with 1011 additions and 164 deletions

View file

@ -8,7 +8,6 @@ import (
"errors"
"fmt"
"net/url"
"sort"
"strconv"
"strings"
@ -765,13 +764,18 @@ var (
"YES": true,
"ZONE": true,
}
// DefaultPostgresSchema default postgres schema
DefaultPostgresSchema = "public"
)
type postgres struct {
core.Base
schema string
}
func (db *postgres) Init(d *core.DB, uri *core.Uri, drivername, dataSourceName string) error {
db.schema = DefaultPostgresSchema
return db.Base.Init(d, db, uri, drivername, dataSourceName)
}
@ -923,7 +927,7 @@ func (db *postgres) IsColumnExist(tableName, colName string) (bool, error) {
func (db *postgres) GetColumns(tableName string) ([]string, map[string]*core.Column, error) {
// FIXME: the schema should be replaced by user custom's
args := []interface{}{tableName, "public"}
args := []interface{}{tableName, db.schema}
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, numeric_precision, numeric_precision_radix ,
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey,
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey
@ -1024,8 +1028,7 @@ WHERE c.relkind = 'r'::char AND c.relname = $1 AND s.table_schema = $2 AND f.att
}
func (db *postgres) GetTables() ([]*core.Table, error) {
// FIXME: replace public to user customrize schema
args := []interface{}{"public"}
args := []interface{}{db.schema}
s := fmt.Sprintf("SELECT tablename FROM pg_tables WHERE schemaname = $1")
db.LogSQL(s, args)
@ -1050,8 +1053,7 @@ func (db *postgres) GetTables() ([]*core.Table, error) {
}
func (db *postgres) GetIndexes(tableName string) (map[string]*core.Index, error) {
// FIXME: replace the public schema to user specify schema
args := []interface{}{"public", tableName}
args := []interface{}{db.schema, tableName}
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE schemaname=$1 AND tablename=$2")
db.LogSQL(s, args)
@ -1117,10 +1119,6 @@ func (vs values) Get(k string) (v string) {
return vs[k]
}
func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
}
func parseURL(connstr string) (string, error) {
u, err := url.Parse(connstr)
if err != nil {
@ -1131,46 +1129,18 @@ func parseURL(connstr string) (string, error) {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
var kvs []string
escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
accrue := func(k, v string) {
if v != "" {
kvs = append(kvs, k+"="+escaper.Replace(v))
}
}
if u.User != nil {
v := u.User.Username()
accrue("user", v)
v, _ = u.User.Password()
accrue("password", v)
}
i := strings.Index(u.Host, ":")
if i < 0 {
accrue("host", u.Host)
} else {
accrue("host", u.Host[:i])
accrue("port", u.Host[i+1:])
}
if u.Path != "" {
accrue("dbname", u.Path[1:])
return escaper.Replace(u.Path[1:]), nil
}
q := u.Query()
for k := range q {
accrue(k, q.Get(k))
}
sort.Strings(kvs) // Makes testing easier (not a performance concern)
return strings.Join(kvs, " "), nil
return "", nil
}
func parseOpts(name string, o values) {
func parseOpts(name string, o values) error {
if len(name) == 0 {
return
return fmt.Errorf("invalid options: %s", name)
}
name = strings.TrimSpace(name)
@ -1179,31 +1149,36 @@ func parseOpts(name string, o values) {
for _, p := range ps {
kv := strings.Split(p, "=")
if len(kv) < 2 {
errorf("invalid option: %q", p)
return fmt.Errorf("invalid option: %q", p)
}
o.Set(kv[0], kv[1])
}
return nil
}
func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.POSTGRES}
o := make(values)
var err error
if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
dataSourceName, err = parseURL(dataSourceName)
db.DbName, err = parseURL(dataSourceName)
if err != nil {
return nil, err
}
} else {
o := make(values)
err = parseOpts(dataSourceName, o)
if err != nil {
return nil, err
}
}
parseOpts(dataSourceName, o)
db.DbName = o.Get("dbname")
db.DbName = o.Get("dbname")
}
if db.DbName == "" {
return nil, errors.New("dbname is empty")
}
/*db.Schema = o.Get("schema")
if len(db.Schema) == 0 {
db.Schema = "public"
}*/
return db, nil
}