mysql OK, partially tested
This commit is contained in:
153
db.go
153
db.go
@@ -9,6 +9,7 @@ import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -79,7 +80,7 @@ func (db *Db) QueryAssociativeArray(query string) (Rows, error) {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// get rows
|
||||
results := Rows{}
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
@@ -87,6 +88,16 @@ func (db *Db) QueryAssociativeArray(query string) (Rows, error) {
|
||||
log.Println(query)
|
||||
return nil, err
|
||||
}
|
||||
// make types map
|
||||
columnTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
columnType := make(map[string]string)
|
||||
for _, colType := range columnTypes {
|
||||
columnType[colType.Name()] = colType.DatabaseTypeName()
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
// Create a slice of interface{}'s to represent each column,
|
||||
// and a second slice to contain pointers to each item in the columns slice.
|
||||
@@ -97,19 +108,69 @@ func (db *Db) QueryAssociativeArray(query string) (Rows, error) {
|
||||
}
|
||||
|
||||
// Scan the result into the column pointers...
|
||||
if err := rows.Scan(columnPointers...); err != nil {
|
||||
|
||||
err = rows.Scan(columnPointers...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create our map, and retrieve the value for each column from the pointers slice,
|
||||
// storing it in the map with the name of the column as the key.
|
||||
m := make(AssRow)
|
||||
for i, colName := range cols {
|
||||
//fmt.Println(colName)
|
||||
val := columnPointers[i].(*interface{})
|
||||
m[colName] = fmt.Sprintf("%v", *val)
|
||||
if db.Driver == "mysql" {
|
||||
if (*val) == nil {
|
||||
m[colName] = nil
|
||||
} else {
|
||||
switch columnType[colName] {
|
||||
case "INT", "BIGINT":
|
||||
i, err := strconv.ParseInt(fmt.Sprintf("%s", *val), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m[colName] = i
|
||||
case "UNSIGNED BIGINT":
|
||||
u, err := strconv.ParseUint(fmt.Sprintf("%s", *val), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m[colName] = u
|
||||
case "FLOAT":
|
||||
f, err := strconv.ParseFloat(fmt.Sprintf("%s", *val), 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m[colName] = f
|
||||
case "TINYINT":
|
||||
i, err := strconv.ParseInt(fmt.Sprintf("%s", *val), 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if i == 1 {
|
||||
m[colName] = true
|
||||
} else {
|
||||
m[colName] = false
|
||||
}
|
||||
|
||||
case "VARCHAR", "TEXT", "TIMESTAMP":
|
||||
m[colName] = fmt.Sprintf("%s", *val)
|
||||
default:
|
||||
if reflect.ValueOf(val).IsNil() {
|
||||
m[colName] = nil
|
||||
} else {
|
||||
fmt.Printf("Unknow type : %s", columnType[colName])
|
||||
m[colName] = fmt.Sprintf("%v", *val)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if db.Driver == "postgres" {
|
||||
m[colName] = *val
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
results = append(results, m)
|
||||
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
@@ -117,7 +178,7 @@ func (db *Db) QueryAssociativeArray(query string) (Rows, error) {
|
||||
// GetSchema : Provide table schema as an associative array
|
||||
func (t *TableInfo) GetSchema() (*TableInfo, error) {
|
||||
pgSchema := "SELECT column_name :: varchar as name, REPLACE(REPLACE(data_type,'character varying','varchar'),'character','char') || COALESCE('(' || character_maximum_length || ')', '') as type, col_description('public." + t.Name + "'::regclass, ordinal_position) as comment from INFORMATION_SCHEMA.COLUMNS where table_name ='" + t.Name + "';"
|
||||
mySchema := "SELECT COLUMN_NAME as name, DATA_TYPE || CHARACTER_MAXIMUM_LENGTH FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '" + t.Name + "';"
|
||||
mySchema := "SELECT COLUMN_NAME as name, CONCAT(DATA_TYPE, COALESCE(CONCAT('(' , CHARACTER_MAXIMUM_LENGTH, ')'), '')) as type FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '" + t.Name + "';"
|
||||
var schemaQuery string
|
||||
var ti TableInfo
|
||||
ti.Name = t.Name
|
||||
@@ -138,13 +199,13 @@ func (t *TableInfo) GetSchema() (*TableInfo, error) {
|
||||
var name, rowtype, comment string
|
||||
for key, element := range row {
|
||||
if key == "name" {
|
||||
name = fmt.Sprintf("%v", element)
|
||||
name = fmt.Sprintf("%s", element)
|
||||
}
|
||||
if key == "type" {
|
||||
rowtype = fmt.Sprintf("%v", element)
|
||||
rowtype = fmt.Sprintf("%s", element)
|
||||
}
|
||||
if key == "comment" {
|
||||
comment = fmt.Sprintf("%v", element)
|
||||
comment = fmt.Sprintf("%s", element)
|
||||
}
|
||||
}
|
||||
ti.Columns[name] = rowtype
|
||||
@@ -200,6 +261,16 @@ func (db *Db) myListTables() (Rows, error) {
|
||||
}
|
||||
|
||||
func (db *Db) CreateTable(t TableInfo) error {
|
||||
if db.Driver == "postgres" {
|
||||
return db.pgCreateTable(t)
|
||||
}
|
||||
if db.Driver == "mysql" {
|
||||
return db.myCreateTable(t)
|
||||
}
|
||||
return errors.New("no driver")
|
||||
}
|
||||
|
||||
func (db *Db) pgCreateTable(t TableInfo) error {
|
||||
t.db = db
|
||||
query := "create table " + t.Name + " ( "
|
||||
columns := ""
|
||||
@@ -234,6 +305,32 @@ func (db *Db) CreateTable(t TableInfo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Db) myCreateTable(t TableInfo) error {
|
||||
t.db = db
|
||||
query := "create table " + t.Name + " ( "
|
||||
columns := ""
|
||||
for name, rowtype := range t.Columns {
|
||||
if fmt.Sprintf("%v", name) == "id" {
|
||||
columns += fmt.Sprintf("%v", name) + " " + "SERIAL PRIMARY KEY,"
|
||||
} else {
|
||||
desc := strings.Split(fmt.Sprintf("%v", rowtype), "|")
|
||||
columns += fmt.Sprintf("%v", name) + " " + desc[0]
|
||||
if len(desc) > 1 {
|
||||
columns += " COMMENT " + pq.QuoteLiteral(desc[1])
|
||||
}
|
||||
columns += ","
|
||||
}
|
||||
}
|
||||
query += columns
|
||||
query = query[:len(query)-1] + " )"
|
||||
_, err := t.db.conn.Query(query)
|
||||
if err != nil {
|
||||
log.Println(err.Error())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TableInfo) DeleteTable() error {
|
||||
query := "drop table " + t.Name
|
||||
_, err := t.db.conn.Query(query)
|
||||
@@ -251,6 +348,16 @@ func (t *TableInfo) DeleteTable() error {
|
||||
}
|
||||
|
||||
func (t *TableInfo) AddColumn(name string, sqltype string, comment string) error {
|
||||
if t.db.Driver == "postgres" {
|
||||
return t.pgAddColumn(name, sqltype, comment)
|
||||
}
|
||||
if t.db.Driver == "mysql" {
|
||||
return t.myAddColumn(name, sqltype, comment)
|
||||
}
|
||||
return errors.New("no driver")
|
||||
}
|
||||
|
||||
func (t *TableInfo) pgAddColumn(name string, sqltype string, comment string) error {
|
||||
query := "alter table " + t.Name + " add " + name + " " + sqltype
|
||||
rows, err := t.db.conn.Query(query)
|
||||
if err != nil {
|
||||
@@ -269,6 +376,20 @@ func (t *TableInfo) AddColumn(name string, sqltype string, comment string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TableInfo) myAddColumn(name string, sqltype string, comment string) error {
|
||||
query := "alter table " + t.Name + " add " + name + " " + sqltype
|
||||
if strings.TrimSpace(comment) != "" {
|
||||
query += " COMMENT " + pq.QuoteLiteral(comment)
|
||||
}
|
||||
rows, err := t.db.conn.Query(query)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TableInfo) DeleteColumn(name string) error {
|
||||
query := "alter table " + t.Name + " drop " + name
|
||||
rows, err := t.db.conn.Query(query)
|
||||
@@ -355,8 +476,16 @@ func (t *TableInfo) Insert(record AssRow) (int, error) {
|
||||
columns += key + ","
|
||||
values += FormatForSQL(t.Columns[key], element) + ","
|
||||
}
|
||||
|
||||
err = t.db.conn.QueryRow("INSERT INTO " + t.Name + "(" + removeLastChar(columns) + ") VALUES (" + removeLastChar(values) + ") RETURNING id").Scan(&id)
|
||||
if t.db.Driver == "postgres" {
|
||||
err = t.db.conn.QueryRow("INSERT INTO " + t.Name + "(" + removeLastChar(columns) + ") VALUES (" + removeLastChar(values) + ") RETURNING id").Scan(&id)
|
||||
}
|
||||
if t.db.Driver == "mysql" {
|
||||
_, err = t.db.conn.Query("INSERT INTO " + t.Name + "(" + removeLastChar(columns) + ") VALUES (" + removeLastChar(values) + ")")
|
||||
if err != nil {
|
||||
return id, err
|
||||
}
|
||||
err = t.db.conn.QueryRow("LAST_INSERT_ID();").Scan(&id)
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user