Global refactor

This commit is contained in:
ycc 2021-10-28 10:34:10 +02:00
parent 6f1cb9a2d0
commit 7f02683866
5 changed files with 192 additions and 135 deletions

21
LICENSE
View File

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2021 redr00m
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,2 +0,0 @@
# sqldb
Generic SQL access libradry

2
go.mod
View File

@ -1,4 +1,4 @@
module github.com/redr00m/sqldb
module forge.redroom.link/yves/sqldb
go 1.15

190
pg.go
View File

@ -10,48 +10,72 @@ import (
"github.com/lib/pq"
)
var db *sql.DB
type Db struct {
Driver string
Url string
conn *sql.DB
}
// AssRow : associative row type
type AssRow map[string]interface{}
// Select Result
type Rows []AssRow
// Table is a table structure description
type Table struct {
type TableInfo struct {
Name string `json:"name"`
Columns map[string]string `json:"columns"`
db *Db
}
// Open the database
func Open(driver string, url string) {
func Open(driver string, url string) *Db {
var database Db
var err error
db, err = sql.Open(driver, url)
database.Driver = driver
database.Url = url
database.conn, err = sql.Open(driver, url)
if err != nil {
log.Println(err)
}
return &database
}
// Close the database connection
func Close() {
db.Close()
func (db *Db) Close() {
db.conn.Close()
}
func (db *Db) Table(name string) *TableInfo {
var ti TableInfo
ti.Name = name
ti.db = db
return &ti
}
// GetAssociativeArray : Provide results as an associative array
func GetAssociativeArray(table string, columns []string, restriction string, sortkeys []string, dir string) []AssRow {
return QueryAssociativeArray(buildSelect(table, "", columns, restriction, sortkeys, dir))
func (t *TableInfo) GetAssociativeArray(columns []string, restriction string, sortkeys []string, dir string) ([]AssRow, error) {
return t.db.QueryAssociativeArray(t.buildSelect("", columns, restriction, sortkeys, dir))
}
// QueryAssociativeArray : Provide results as an associative array
func QueryAssociativeArray(query string) []AssRow {
rows, err := db.Query(query)
func (db *Db) QueryAssociativeArray(query string) (Rows, error) {
rows, err := db.conn.Query(query)
if err != nil {
log.Println(err)
log.Println(query)
return nil, err
}
defer rows.Close()
var results []AssRow
cols, _ := rows.Columns()
var results Rows
cols, err := rows.Columns()
if err != nil {
log.Println(err)
log.Println(query)
return nil, err
}
for rows.Next() {
// Create a slice of interface{}'s to represent each column,
// and a second slice to contain pointers to each item in the columns slice.
@ -70,23 +94,26 @@ func QueryAssociativeArray(query string) []AssRow {
m := make(AssRow)
for i, colName := range cols {
val := columnPointers[i].(*interface{})
m[colName] = *val
m[colName] = fmt.Sprintf("%v", *val)
}
// jsonString, _ := json.Marshal(m)
// Outputs: map[columnName:value columnName2:value2 columnName3:value3 ...]
// fmt.Println(string(jsonString))
results = append(results, m)
}
return results
return results, nil
}
// GetSchema : Provide results as an associative array
func GetSchema(table string) Table {
t := Table{Name: table}
cols := QueryAssociativeArray("SELECT column_name :: varchar as name, REPLACE(REPLACE(data_type,'character varying','varchar'),'character','char') || COALESCE('(' || character_maximum_length || ')', '') as type from INFORMATION_SCHEMA.COLUMNS where table_name ='" + table + "';")
t.Columns = make(map[string]string)
func (t *TableInfo) GetSchema() (*TableInfo, error) {
var ti TableInfo
ti.Name = t.Name
ti.db = t.db
cols, err := t.db.QueryAssociativeArray("SELECT column_name :: varchar as name, REPLACE(REPLACE(data_type,'character varying','varchar'),'character','char') || COALESCE('(' || character_maximum_length || ')', '') as type from INFORMATION_SCHEMA.COLUMNS where table_name ='" + t.Name + "';")
if err != nil {
log.Println(err)
return nil, err
}
ti.Columns = make(map[string]string)
for _, row := range cols {
var name, rowtype string
for key, element := range row {
@ -97,16 +124,17 @@ func GetSchema(table string) Table {
rowtype = fmt.Sprintf("%v", element)
}
}
t.Columns[name] = rowtype
ti.Columns[name] = rowtype
}
return t
return &ti, nil
}
func ListTables() []AssRow {
return QueryAssociativeArray("SELECT table_name :: varchar FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name;")
func (db *Db) ListTables() (Rows, error) {
return db.QueryAssociativeArray("SELECT table_name :: varchar FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name;")
}
func CreateTable(t Table) {
func (db *Db) CreateTable(t TableInfo) error {
t.db = db
query := "create table " + t.Name + " ( "
columns := ""
for name, rowtype := range t.Columns {
@ -120,57 +148,61 @@ func CreateTable(t Table) {
}
query += columns
query = query[:len(query)-1] + " )"
_, err := db.Query(query)
_, err := t.db.conn.Query(query)
if err != nil {
log.Println(query)
}
query = "create sequence if not exists sq_" + t.Name
_, err = db.Query(query)
if err != nil {
log.Println(query)
log.Println(err.Error())
return err
}
return nil
}
func DeleteTable(table string) {
query := "drop table " + table
_, err := db.Query(query)
func (t *TableInfo) DeleteTable() error {
query := "drop table " + t.Name
_, err := t.db.conn.Query(query)
if err != nil {
log.Println(err)
log.Println(err.Error())
return err
}
query = "drop sequence if exists sq_" + table
_, err = db.Query(query)
query = "drop sequence if exists sq_" + t.Name
_, err = t.db.conn.Query(query)
if err != nil {
log.Println(err)
log.Println(err.Error())
return err
}
return nil
}
func AddColumn(table string, name string, sqltype string) {
query := "alter table " + table + " add " + name + " " + sqltype
rows, err := db.Query(query)
func (t *TableInfo) AddColumn(name string, sqltype string) error {
query := "alter table " + t.Name + " add " + name + " " + sqltype
rows, err := t.db.conn.Query(query)
if err != nil {
log.Println(err)
return err
}
defer rows.Close()
return nil
}
func DeleteColumn(table string, name string) {
query := "alter table " + table + " drop " + name
rows, err := db.Query(query)
func (t *TableInfo) DeleteColumn(name string) error {
query := "alter table " + t.Name + " drop " + name
rows, err := t.db.conn.Query(query)
if err != nil {
log.Println(err)
return err
}
defer rows.Close()
return nil
}
func ListSequences() []AssRow {
return QueryAssociativeArray("SELECT sequence_name :: varchar FROM information_schema.sequences WHERE sequence_schema = 'public' ORDER BY sequence_name;")
func (db *Db) ListSequences() (Rows, error) {
return db.QueryAssociativeArray("SELECT sequence_name :: varchar FROM information_schema.sequences WHERE sequence_schema = 'public' ORDER BY sequence_name;")
}
func buildSelect(table string, key string, columns []string, restriction string, sortkeys []string, dir ...string) string {
func (t *TableInfo) buildSelect(key string, columns []string, restriction string, sortkeys []string, dir ...string) string {
if key != "" {
columns = append(columns, key)
}
query := "select " + strings.Join(columns, ",") + " from " + table
query := "select " + strings.Join(columns, ",") + " from " + t.Name
if restriction != "" {
query += " where " + restriction
}
@ -183,17 +215,21 @@ func buildSelect(table string, key string, columns []string, restriction string,
return query
}
func Insert(table string, record AssRow) int {
func (t *TableInfo) Insert(record AssRow) (int, error) {
columns := ""
values := ""
schema := GetSchema(table)
t, err := t.GetSchema()
if err != nil {
log.Println(err)
return -1, err
}
var id int
for key, element := range record {
if strings.Contains(schema.Columns[key], "char") || strings.Contains(schema.Columns[key], "date") {
if strings.Contains(t.Columns[key], "char") || strings.Contains(t.Columns[key], "date") {
columns += key + ","
values += fmt.Sprintf(pq.QuoteLiteral(fmt.Sprintf("%v", element))) + ","
values += fmt.Sprint(pq.QuoteLiteral(fmt.Sprintf("%v", element))) + ","
} else {
columns += key + ","
@ -201,19 +237,23 @@ func Insert(table string, record AssRow) int {
}
}
db.QueryRow("INSERT INTO " + table + "(" + removeLastChar(columns) + ") VALUES (" + removeLastChar(values) + ") RETURNING id").Scan(&id)
return id
t.db.conn.QueryRow("INSERT INTO " + t.Name + "(" + removeLastChar(columns) + ") VALUES (" + removeLastChar(values) + ") RETURNING id").Scan(&id)
return id, nil
}
func Update(table string, record AssRow) string {
func (t *TableInfo) Update(record AssRow) error {
schema := GetSchema(table)
t, err := t.GetSchema()
if err != nil {
log.Println(err)
return err
}
id := ""
stack := ""
for key, element := range record {
if strings.Contains(schema.Columns[key], "char") || strings.Contains(schema.Columns[key], "date") {
if strings.Contains(t.Columns[key], "char") || strings.Contains(t.Columns[key], "date") {
stack = stack + " " + key + " = " + pq.QuoteLiteral(fmt.Sprintf("%v", element)) + ","
@ -227,17 +267,18 @@ func Update(table string, record AssRow) string {
}
}
stack = removeLastChar(stack)
query := ("UPDATE " + table + " SET " + stack + " WHERE id = " + id)
rows, err := db.Query(query)
query := ("UPDATE " + t.Name + " SET " + stack + " WHERE id = " + id)
rows, err := t.db.conn.Query(query)
if err != nil {
log.Println(query)
log.Println(err)
return err
}
defer rows.Close()
return query
return nil
}
func Delete(table string, record AssRow) string {
func (t *TableInfo) Delete(record AssRow) error {
id := ""
values := ""
@ -248,31 +289,32 @@ func Delete(table string, record AssRow) string {
}
}
query := ("DELETE FROM " + table + " WHERE id = " + id)
rows, err := db.Query(query)
query := ("DELETE FROM " + t.Name + " WHERE id = " + id)
rows, err := t.db.conn.Query(query)
if err != nil {
log.Println(query)
log.Println(err)
return err
}
defer rows.Close()
return query
return nil
}
func UpdateOrInsert(table string, record AssRow) int {
id := 0
func (t *TableInfo) UpdateOrInsert(record AssRow) (int, error) {
id := -1
for key, element := range record {
if key == "id" {
sid := fmt.Sprintf("%v", element)
id, _ = strconv.Atoi(sid)
}
}
if id == 0 {
return Insert(table, record)
if id == -1 {
return t.Insert(record)
} else {
Update(table, record)
return id
t.Update(record)
return id, nil
}
}
func removeLastChar(s string) string {

View File

@ -9,8 +9,8 @@ import (
)
func TestCreateTable(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close()
db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer db.Close()
jsonFile, err := os.Open("test_table.json")
if err != nil {
@ -20,25 +20,37 @@ func TestCreateTable(t *testing.T) {
byteValue, _ := ioutil.ReadAll(jsonFile)
var data Table
json.Unmarshal([]byte(byteValue), &data)
var jsonSource TableInfo
json.Unmarshal([]byte(byteValue), &jsonSource)
CreateTable(data)
err = db.CreateTable(jsonSource)
if err != nil {
fmt.Println(err.Error())
}
tbl := GetSchema(data.Name)
if len(tbl.Columns) == 0 {
sch, err := db.Table("test").GetSchema()
if err != nil {
fmt.Println(err.Error())
}
if len(sch.Columns) == 0 {
t.Errorf("Create table failed")
}
}
func TestAddColumn(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close()
db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer db.Close()
old := GetSchema("test")
AddColumn("test", "addcolumn", "integer")
new := GetSchema("test")
old, err := db.Table("test").GetSchema()
if err != nil {
fmt.Println(err.Error())
}
db.Table("test").AddColumn("addcolumn", "integer")
new, err := db.Table("test").GetSchema()
if err != nil {
fmt.Println(err.Error())
}
if len(old.Columns) == len(new.Columns) {
t.Errorf("Column already exist")
@ -47,20 +59,26 @@ func TestAddColumn(t *testing.T) {
func TestInsert(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close()
db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer db.Close()
vl := make(AssRow)
vl["name"] = "toto"
vl["description"] = "tata"
old := GetAssociativeArray("test", []string{"*"}, "", []string{}, "")
old, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "")
if err != nil {
fmt.Println(err.Error())
}
jsonStringOld, _ := json.Marshal(old)
fmt.Println(string(jsonStringOld))
UpdateOrInsert("test", vl)
db.Table("test").UpdateOrInsert(vl)
new := GetAssociativeArray("test", []string{"*"}, "", []string{}, "")
new, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "")
if err != nil {
fmt.Println(err.Error())
}
jsonStringNew, _ := json.Marshal(new)
fmt.Println(string(jsonStringNew))
@ -71,21 +89,27 @@ func TestInsert(t *testing.T) {
func TestUpdate(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close()
db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer db.Close()
vl := make(AssRow)
vl["id"] = 1
vl["name"] = "titi"
vl["description"] = "toto"
old := GetAssociativeArray("test", []string{"*"}, "", []string{}, "")
old, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "")
if err != nil {
fmt.Println(err.Error())
}
jsonStringOld, _ := json.Marshal(old)
fmt.Println(string(jsonStringOld))
UpdateOrInsert("test", vl)
db.Table("test").UpdateOrInsert(vl)
new := GetAssociativeArray("test", []string{"*"}, "", []string{}, "")
new, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "")
if err != nil {
fmt.Println(err.Error())
}
jsonStringNew, _ := json.Marshal(new)
fmt.Println(string(jsonStringNew))
@ -97,19 +121,25 @@ func TestUpdate(t *testing.T) {
func TestDelete(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close()
db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer db.Close()
vl := make(AssRow)
vl["id"] = 1
old := GetAssociativeArray("test", []string{"*"}, "", []string{}, "")
old, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "")
if err != nil {
fmt.Println(err.Error())
}
jsonStringOld, _ := json.Marshal(old)
fmt.Println(string(jsonStringOld))
Delete("test", vl)
db.Table("test").Delete(vl)
new := GetAssociativeArray("test", []string{"*"}, "", []string{}, "")
new, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "")
if err != nil {
fmt.Println(err.Error())
}
jsonStringNew, _ := json.Marshal(new)
fmt.Println(string(jsonStringNew))
@ -120,12 +150,18 @@ func TestDelete(t *testing.T) {
func TestDeleteColumn(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close()
db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer db.Close()
old := GetSchema("test")
DeleteColumn("test", "addcolumn")
new := GetSchema("test")
old, err := db.Table("test").GetSchema()
if err != nil {
fmt.Println(err.Error())
}
db.Table("test").DeleteColumn("addcolumn")
new, err := db.Table("test").GetSchema()
if err != nil {
fmt.Println(err.Error())
}
if len(old.Columns) == len(new.Columns) {
t.Errorf("Error column not deleted")
@ -133,13 +169,15 @@ func TestDeleteColumn(t *testing.T) {
}
func TestDeleteTable(t *testing.T) {
Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer Close()
db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable")
defer db.Close()
DeleteTable("test")
tbl := GetSchema("test")
db.Table("test").DeleteTable()
tbl, err := db.Table("test").GetSchema()
if err != nil {
fmt.Println(err.Error())
}
if len(tbl.Columns) != 0 {
t.Errorf("Delete table failed")
}