Global refactor

This commit is contained in:
ycc 2021-10-28 10:34:10 +02:00
parent 6f1cb9a2d0
commit 7f02683866
5 changed files with 192 additions and 135 deletions

21
LICENSE
View File

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2021 redr00m
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,2 +0,0 @@
# sqldb
Generic SQL access libradry

2
go.mod
View File

@ -1,4 +1,4 @@
module github.com/redr00m/sqldb module forge.redroom.link/yves/sqldb
go 1.15 go 1.15

190
pg.go
View File

@ -10,48 +10,72 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
var db *sql.DB type Db struct {
Driver string
Url string
conn *sql.DB
}
// AssRow : associative row type // AssRow : associative row type
type AssRow map[string]interface{} type AssRow map[string]interface{}
// Select Result
type Rows []AssRow
// Table is a table structure description // Table is a table structure description
type Table struct { type TableInfo struct {
Name string `json:"name"` Name string `json:"name"`
Columns map[string]string `json:"columns"` Columns map[string]string `json:"columns"`
db *Db
} }
// Open the database // Open the database
func Open(driver string, url string) { func Open(driver string, url string) *Db {
var database Db
var err error var err error
db, err = sql.Open(driver, url) database.Driver = driver
database.Url = url
database.conn, err = sql.Open(driver, url)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
} }
return &database
} }
// Close the database connection // Close the database connection
func Close() { func (db *Db) Close() {
db.Close() db.conn.Close()
}
func (db *Db) Table(name string) *TableInfo {
var ti TableInfo
ti.Name = name
ti.db = db
return &ti
} }
// GetAssociativeArray : Provide results as an associative array // GetAssociativeArray : Provide results as an associative array
func GetAssociativeArray(table string, columns []string, restriction string, sortkeys []string, dir string) []AssRow { func (t *TableInfo) GetAssociativeArray(columns []string, restriction string, sortkeys []string, dir string) ([]AssRow, error) {
return QueryAssociativeArray(buildSelect(table, "", columns, restriction, sortkeys, dir)) return t.db.QueryAssociativeArray(t.buildSelect("", columns, restriction, sortkeys, dir))
} }
// QueryAssociativeArray : Provide results as an associative array // QueryAssociativeArray : Provide results as an associative array
func QueryAssociativeArray(query string) []AssRow { func (db *Db) QueryAssociativeArray(query string) (Rows, error) {
rows, err := db.Query(query) rows, err := db.conn.Query(query)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
log.Println(query) log.Println(query)
return nil, err
} }
defer rows.Close() defer rows.Close()
var results []AssRow var results Rows
cols, _ := rows.Columns() cols, err := rows.Columns()
if err != nil {
log.Println(err)
log.Println(query)
return nil, err
}
for rows.Next() { for rows.Next() {
// Create a slice of interface{}'s to represent each column, // Create a slice of interface{}'s to represent each column,
// and a second slice to contain pointers to each item in the columns slice. // and a second slice to contain pointers to each item in the columns slice.
@ -70,23 +94,26 @@ func QueryAssociativeArray(query string) []AssRow {
m := make(AssRow) m := make(AssRow)
for i, colName := range cols { for i, colName := range cols {
val := columnPointers[i].(*interface{}) val := columnPointers[i].(*interface{})
m[colName] = *val m[colName] = fmt.Sprintf("%v", *val)
} }
// jsonString, _ := json.Marshal(m)
// Outputs: map[columnName:value columnName2:value2 columnName3:value3 ...]
// fmt.Println(string(jsonString))
results = append(results, m) results = append(results, m)
} }
return results return results, nil
} }
// GetSchema : Provide results as an associative array // GetSchema : Provide results as an associative array
func GetSchema(table string) Table { func (t *TableInfo) GetSchema() (*TableInfo, error) {
t := Table{Name: table} var ti TableInfo
cols := QueryAssociativeArray("SELECT column_name :: varchar as name, REPLACE(REPLACE(data_type,'character varying','varchar'),'character','char') || COALESCE('(' || character_maximum_length || ')', '') as type from INFORMATION_SCHEMA.COLUMNS where table_name ='" + table + "';") ti.Name = t.Name
t.Columns = make(map[string]string) ti.db = t.db
cols, err := t.db.QueryAssociativeArray("SELECT column_name :: varchar as name, REPLACE(REPLACE(data_type,'character varying','varchar'),'character','char') || COALESCE('(' || character_maximum_length || ')', '') as type from INFORMATION_SCHEMA.COLUMNS where table_name ='" + t.Name + "';")
if err != nil {
log.Println(err)
return nil, err
}
ti.Columns = make(map[string]string)
for _, row := range cols { for _, row := range cols {
var name, rowtype string var name, rowtype string
for key, element := range row { for key, element := range row {
@ -97,16 +124,17 @@ func GetSchema(table string) Table {
rowtype = fmt.Sprintf("%v", element) rowtype = fmt.Sprintf("%v", element)
} }
} }
t.Columns[name] = rowtype ti.Columns[name] = rowtype
} }
return t return &ti, nil
} }
func ListTables() []AssRow { func (db *Db) ListTables() (Rows, error) {
return QueryAssociativeArray("SELECT table_name :: varchar FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name;") return db.QueryAssociativeArray("SELECT table_name :: varchar FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name;")
} }
func CreateTable(t Table) { func (db *Db) CreateTable(t TableInfo) error {
t.db = db
query := "create table " + t.Name + " ( " query := "create table " + t.Name + " ( "
columns := "" columns := ""
for name, rowtype := range t.Columns { for name, rowtype := range t.Columns {
@ -120,57 +148,61 @@ func CreateTable(t Table) {
} }
query += columns query += columns
query = query[:len(query)-1] + " )" query = query[:len(query)-1] + " )"
_, err := db.Query(query) _, err := t.db.conn.Query(query)
if err != nil { if err != nil {
log.Println(query) log.Println(err.Error())
} return err
query = "create sequence if not exists sq_" + t.Name
_, err = db.Query(query)
if err != nil {
log.Println(query)
} }
return nil
} }
func DeleteTable(table string) { func (t *TableInfo) DeleteTable() error {
query := "drop table " + table query := "drop table " + t.Name
_, err := db.Query(query) _, err := t.db.conn.Query(query)
if err != nil { if err != nil {
log.Println(err) log.Println(err.Error())
return err
} }
query = "drop sequence if exists sq_" + table query = "drop sequence if exists sq_" + t.Name
_, err = db.Query(query) _, err = t.db.conn.Query(query)
if err != nil { if err != nil {
log.Println(err) log.Println(err.Error())
return err
} }
return nil
} }
func AddColumn(table string, name string, sqltype string) { func (t *TableInfo) AddColumn(name string, sqltype string) error {
query := "alter table " + table + " add " + name + " " + sqltype query := "alter table " + t.Name + " add " + name + " " + sqltype
rows, err := db.Query(query) rows, err := t.db.conn.Query(query)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return err
} }
defer rows.Close() defer rows.Close()
return nil
} }
func DeleteColumn(table string, name string) { func (t *TableInfo) DeleteColumn(name string) error {
query := "alter table " + table + " drop " + name query := "alter table " + t.Name + " drop " + name
rows, err := db.Query(query) rows, err := t.db.conn.Query(query)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return err
} }
defer rows.Close() defer rows.Close()
return nil
} }
func ListSequences() []AssRow { func (db *Db) ListSequences() (Rows, error) {
return QueryAssociativeArray("SELECT sequence_name :: varchar FROM information_schema.sequences WHERE sequence_schema = 'public' ORDER BY sequence_name;") return db.QueryAssociativeArray("SELECT sequence_name :: varchar FROM information_schema.sequences WHERE sequence_schema = 'public' ORDER BY sequence_name;")
} }
func buildSelect(table string, key string, columns []string, restriction string, sortkeys []string, dir ...string) string { func (t *TableInfo) buildSelect(key string, columns []string, restriction string, sortkeys []string, dir ...string) string {
if key != "" { if key != "" {
columns = append(columns, key) columns = append(columns, key)
} }
query := "select " + strings.Join(columns, ",") + " from " + table query := "select " + strings.Join(columns, ",") + " from " + t.Name
if restriction != "" { if restriction != "" {
query += " where " + restriction query += " where " + restriction
} }
@ -183,17 +215,21 @@ func buildSelect(table string, key string, columns []string, restriction string,
return query return query
} }
func Insert(table string, record AssRow) int { func (t *TableInfo) Insert(record AssRow) (int, error) {
columns := "" columns := ""
values := "" values := ""
schema := GetSchema(table) t, err := t.GetSchema()
if err != nil {
log.Println(err)
return -1, err
}
var id int var id int
for key, element := range record { for key, element := range record {
if strings.Contains(schema.Columns[key], "char") || strings.Contains(schema.Columns[key], "date") { if strings.Contains(t.Columns[key], "char") || strings.Contains(t.Columns[key], "date") {
columns += key + "," columns += key + ","
values += fmt.Sprintf(pq.QuoteLiteral(fmt.Sprintf("%v", element))) + "," values += fmt.Sprint(pq.QuoteLiteral(fmt.Sprintf("%v", element))) + ","
} else { } else {
columns += key + "," columns += key + ","
@ -201,19 +237,23 @@ func Insert(table string, record AssRow) int {
} }
} }
db.QueryRow("INSERT INTO " + table + "(" + removeLastChar(columns) + ") VALUES (" + removeLastChar(values) + ") RETURNING id").Scan(&id) t.db.conn.QueryRow("INSERT INTO " + t.Name + "(" + removeLastChar(columns) + ") VALUES (" + removeLastChar(values) + ") RETURNING id").Scan(&id)
return id return id, nil
} }
func Update(table string, record AssRow) string { func (t *TableInfo) Update(record AssRow) error {
schema := GetSchema(table) t, err := t.GetSchema()
if err != nil {
log.Println(err)
return err
}
id := "" id := ""
stack := "" stack := ""
for key, element := range record { for key, element := range record {
if strings.Contains(schema.Columns[key], "char") || strings.Contains(schema.Columns[key], "date") { if strings.Contains(t.Columns[key], "char") || strings.Contains(t.Columns[key], "date") {
stack = stack + " " + key + " = " + pq.QuoteLiteral(fmt.Sprintf("%v", element)) + "," stack = stack + " " + key + " = " + pq.QuoteLiteral(fmt.Sprintf("%v", element)) + ","
@ -227,17 +267,18 @@ func Update(table string, record AssRow) string {
} }
} }
stack = removeLastChar(stack) stack = removeLastChar(stack)
query := ("UPDATE " + table + " SET " + stack + " WHERE id = " + id) query := ("UPDATE " + t.Name + " SET " + stack + " WHERE id = " + id)
rows, err := db.Query(query) rows, err := t.db.conn.Query(query)
if err != nil { if err != nil {
log.Println(query) log.Println(query)
log.Println(err) log.Println(err)
return err
} }
defer rows.Close() defer rows.Close()
return query return nil
} }
func Delete(table string, record AssRow) string { func (t *TableInfo) Delete(record AssRow) error {
id := "" id := ""
values := "" values := ""
@ -248,31 +289,32 @@ func Delete(table string, record AssRow) string {
} }
} }
query := ("DELETE FROM " + table + " WHERE id = " + id) query := ("DELETE FROM " + t.Name + " WHERE id = " + id)
rows, err := db.Query(query) rows, err := t.db.conn.Query(query)
if err != nil { if err != nil {
log.Println(query) log.Println(query)
log.Println(err) log.Println(err)
return err
} }
defer rows.Close() defer rows.Close()
return query return nil
} }
func UpdateOrInsert(table string, record AssRow) int { func (t *TableInfo) UpdateOrInsert(record AssRow) (int, error) {
id := 0 id := -1
for key, element := range record { for key, element := range record {
if key == "id" { if key == "id" {
sid := fmt.Sprintf("%v", element) sid := fmt.Sprintf("%v", element)
id, _ = strconv.Atoi(sid) id, _ = strconv.Atoi(sid)
} }
} }
if id == 0 { if id == -1 {
return Insert(table, record) return t.Insert(record)
} else { } else {
Update(table, record) t.Update(record)
return id return id, nil
} }
} }
func removeLastChar(s string) string { func removeLastChar(s string) string {

View File

@ -9,8 +9,8 @@ import (
) )
func TestCreateTable(t *testing.T) { func TestCreateTable(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close() defer db.Close()
jsonFile, err := os.Open("test_table.json") jsonFile, err := os.Open("test_table.json")
if err != nil { if err != nil {
@ -20,25 +20,37 @@ func TestCreateTable(t *testing.T) {
byteValue, _ := ioutil.ReadAll(jsonFile) byteValue, _ := ioutil.ReadAll(jsonFile)
var data Table var jsonSource TableInfo
json.Unmarshal([]byte(byteValue), &data) json.Unmarshal([]byte(byteValue), &jsonSource)
CreateTable(data) err = db.CreateTable(jsonSource)
if err != nil {
fmt.Println(err.Error())
}
tbl := GetSchema(data.Name) sch, err := db.Table("test").GetSchema()
if len(tbl.Columns) == 0 { if err != nil {
fmt.Println(err.Error())
}
if len(sch.Columns) == 0 {
t.Errorf("Create table failed") t.Errorf("Create table failed")
} }
} }
func TestAddColumn(t *testing.T) { func TestAddColumn(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close() defer db.Close()
old := GetSchema("test") old, err := db.Table("test").GetSchema()
AddColumn("test", "addcolumn", "integer") if err != nil {
new := GetSchema("test") fmt.Println(err.Error())
}
db.Table("test").AddColumn("addcolumn", "integer")
new, err := db.Table("test").GetSchema()
if err != nil {
fmt.Println(err.Error())
}
if len(old.Columns) == len(new.Columns) { if len(old.Columns) == len(new.Columns) {
t.Errorf("Column already exist") t.Errorf("Column already exist")
@ -47,20 +59,26 @@ func TestAddColumn(t *testing.T) {
func TestInsert(t *testing.T) { func TestInsert(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close() defer db.Close()
vl := make(AssRow) vl := make(AssRow)
vl["name"] = "toto" vl["name"] = "toto"
vl["description"] = "tata" vl["description"] = "tata"
old := GetAssociativeArray("test", []string{"*"}, "", []string{}, "") old, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "")
if err != nil {
fmt.Println(err.Error())
}
jsonStringOld, _ := json.Marshal(old) jsonStringOld, _ := json.Marshal(old)
fmt.Println(string(jsonStringOld)) fmt.Println(string(jsonStringOld))
UpdateOrInsert("test", vl) db.Table("test").UpdateOrInsert(vl)
new := GetAssociativeArray("test", []string{"*"}, "", []string{}, "") new, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "")
if err != nil {
fmt.Println(err.Error())
}
jsonStringNew, _ := json.Marshal(new) jsonStringNew, _ := json.Marshal(new)
fmt.Println(string(jsonStringNew)) fmt.Println(string(jsonStringNew))
@ -71,21 +89,27 @@ func TestInsert(t *testing.T) {
func TestUpdate(t *testing.T) { func TestUpdate(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close() defer db.Close()
vl := make(AssRow) vl := make(AssRow)
vl["id"] = 1 vl["id"] = 1
vl["name"] = "titi" vl["name"] = "titi"
vl["description"] = "toto" vl["description"] = "toto"
old := GetAssociativeArray("test", []string{"*"}, "", []string{}, "") old, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "")
if err != nil {
fmt.Println(err.Error())
}
jsonStringOld, _ := json.Marshal(old) jsonStringOld, _ := json.Marshal(old)
fmt.Println(string(jsonStringOld)) fmt.Println(string(jsonStringOld))
UpdateOrInsert("test", vl) db.Table("test").UpdateOrInsert(vl)
new := GetAssociativeArray("test", []string{"*"}, "", []string{}, "") new, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "")
if err != nil {
fmt.Println(err.Error())
}
jsonStringNew, _ := json.Marshal(new) jsonStringNew, _ := json.Marshal(new)
fmt.Println(string(jsonStringNew)) fmt.Println(string(jsonStringNew))
@ -97,19 +121,25 @@ func TestUpdate(t *testing.T) {
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close() defer db.Close()
vl := make(AssRow) vl := make(AssRow)
vl["id"] = 1 vl["id"] = 1
old := GetAssociativeArray("test", []string{"*"}, "", []string{}, "") old, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "")
if err != nil {
fmt.Println(err.Error())
}
jsonStringOld, _ := json.Marshal(old) jsonStringOld, _ := json.Marshal(old)
fmt.Println(string(jsonStringOld)) fmt.Println(string(jsonStringOld))
Delete("test", vl) db.Table("test").Delete(vl)
new := GetAssociativeArray("test", []string{"*"}, "", []string{}, "") new, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "")
if err != nil {
fmt.Println(err.Error())
}
jsonStringNew, _ := json.Marshal(new) jsonStringNew, _ := json.Marshal(new)
fmt.Println(string(jsonStringNew)) fmt.Println(string(jsonStringNew))
@ -120,12 +150,18 @@ func TestDelete(t *testing.T) {
func TestDeleteColumn(t *testing.T) { func TestDeleteColumn(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close() defer db.Close()
old := GetSchema("test") old, err := db.Table("test").GetSchema()
DeleteColumn("test", "addcolumn") if err != nil {
new := GetSchema("test") fmt.Println(err.Error())
}
db.Table("test").DeleteColumn("addcolumn")
new, err := db.Table("test").GetSchema()
if err != nil {
fmt.Println(err.Error())
}
if len(old.Columns) == len(new.Columns) { if len(old.Columns) == len(new.Columns) {
t.Errorf("Error column not deleted") t.Errorf("Error column not deleted")
@ -133,13 +169,15 @@ func TestDeleteColumn(t *testing.T) {
} }
func TestDeleteTable(t *testing.T) { func TestDeleteTable(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close() defer db.Close()
DeleteTable("test") db.Table("test").DeleteTable()
tbl := GetSchema("test")
tbl, err := db.Table("test").GetSchema()
if err != nil {
fmt.Println(err.Error())
}
if len(tbl.Columns) != 0 { if len(tbl.Columns) != 0 {
t.Errorf("Delete table failed") t.Errorf("Delete table failed")
} }