From 7f0268386646bd234a83bd2c7baf558e010ba0fe Mon Sep 17 00:00:00 2001 From: ycc Date: Thu, 28 Oct 2021 10:34:10 +0200 Subject: [PATCH] Global refactor --- LICENSE | 21 ------ README.md | 2 - go.mod | 2 +- pg.go | 190 ++++++++++++++++++++++++++++++++--------------------- pg_test.go | 112 ++++++++++++++++++++----------- 5 files changed, 192 insertions(+), 135 deletions(-) delete mode 100644 LICENSE delete mode 100644 README.md diff --git a/LICENSE b/LICENSE deleted file mode 100644 index aa7678a..0000000 --- a/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2021 redr00m - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/README.md b/README.md deleted file mode 100644 index f782e22..0000000 --- a/README.md +++ /dev/null @@ -1,2 +0,0 @@ -# sqldb -Generic SQL access libradry diff --git a/go.mod b/go.mod index ceeb144..03dae42 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/redr00m/sqldb +module forge.redroom.link/yves/sqldb go 1.15 diff --git a/pg.go b/pg.go index 94f63a6..952559e 100755 --- a/pg.go +++ b/pg.go @@ -10,48 +10,72 @@ import ( "github.com/lib/pq" ) -var db *sql.DB +type Db struct { + Driver string + Url string + conn *sql.DB +} // AssRow : associative row type type AssRow map[string]interface{} +// Select Result +type Rows []AssRow + // Table is a table structure description -type Table struct { +type TableInfo struct { Name string `json:"name"` Columns map[string]string `json:"columns"` + db *Db } // Open the database -func Open(driver string, url string) { +func Open(driver string, url string) *Db { + var database Db var err error - db, err = sql.Open(driver, url) + database.Driver = driver + database.Url = url + database.conn, err = sql.Open(driver, url) if err != nil { log.Println(err) } + return &database } // Close the database connection -func Close() { - db.Close() +func (db *Db) Close() { + db.conn.Close() +} + +func (db *Db) Table(name string) *TableInfo { + var ti TableInfo + ti.Name = name + ti.db = db + return &ti } // GetAssociativeArray : Provide results as an associative array -func GetAssociativeArray(table string, columns []string, restriction string, sortkeys []string, dir string) []AssRow { - return QueryAssociativeArray(buildSelect(table, "", columns, restriction, sortkeys, dir)) +func (t *TableInfo) GetAssociativeArray(columns []string, restriction string, sortkeys []string, dir string) ([]AssRow, error) { + return t.db.QueryAssociativeArray(t.buildSelect("", columns, restriction, sortkeys, dir)) } // QueryAssociativeArray : Provide results as an associative array -func QueryAssociativeArray(query string) []AssRow { - rows, err := db.Query(query) +func (db *Db) QueryAssociativeArray(query string) (Rows, error) { + rows, err := db.conn.Query(query) if err != nil { log.Println(err) log.Println(query) + return nil, err } defer rows.Close() - var results []AssRow - cols, _ := rows.Columns() - + var results Rows + cols, err := rows.Columns() + if err != nil { + log.Println(err) + log.Println(query) + return nil, err + } for rows.Next() { // Create a slice of interface{}'s to represent each column, // and a second slice to contain pointers to each item in the columns slice. @@ -70,23 +94,26 @@ func QueryAssociativeArray(query string) []AssRow { m := make(AssRow) for i, colName := range cols { val := columnPointers[i].(*interface{}) - m[colName] = *val + m[colName] = fmt.Sprintf("%v", *val) } - // jsonString, _ := json.Marshal(m) - // Outputs: map[columnName:value columnName2:value2 columnName3:value3 ...] - // fmt.Println(string(jsonString)) results = append(results, m) } - return results + return results, nil } // GetSchema : Provide results as an associative array -func GetSchema(table string) Table { - t := Table{Name: table} - cols := QueryAssociativeArray("SELECT column_name :: varchar as name, REPLACE(REPLACE(data_type,'character varying','varchar'),'character','char') || COALESCE('(' || character_maximum_length || ')', '') as type from INFORMATION_SCHEMA.COLUMNS where table_name ='" + table + "';") - t.Columns = make(map[string]string) +func (t *TableInfo) GetSchema() (*TableInfo, error) { + var ti TableInfo + ti.Name = t.Name + ti.db = t.db + cols, err := t.db.QueryAssociativeArray("SELECT column_name :: varchar as name, REPLACE(REPLACE(data_type,'character varying','varchar'),'character','char') || COALESCE('(' || character_maximum_length || ')', '') as type from INFORMATION_SCHEMA.COLUMNS where table_name ='" + t.Name + "';") + if err != nil { + log.Println(err) + return nil, err + } + ti.Columns = make(map[string]string) for _, row := range cols { var name, rowtype string for key, element := range row { @@ -97,16 +124,17 @@ func GetSchema(table string) Table { rowtype = fmt.Sprintf("%v", element) } } - t.Columns[name] = rowtype + ti.Columns[name] = rowtype } - return t + return &ti, nil } -func ListTables() []AssRow { - return QueryAssociativeArray("SELECT table_name :: varchar FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name;") +func (db *Db) ListTables() (Rows, error) { + return db.QueryAssociativeArray("SELECT table_name :: varchar FROM information_schema.tables WHERE table_schema = 'public' ORDER BY table_name;") } -func CreateTable(t Table) { +func (db *Db) CreateTable(t TableInfo) error { + t.db = db query := "create table " + t.Name + " ( " columns := "" for name, rowtype := range t.Columns { @@ -120,57 +148,61 @@ func CreateTable(t Table) { } query += columns query = query[:len(query)-1] + " )" - _, err := db.Query(query) + _, err := t.db.conn.Query(query) if err != nil { - log.Println(query) - } - query = "create sequence if not exists sq_" + t.Name - _, err = db.Query(query) - if err != nil { - log.Println(query) + log.Println(err.Error()) + return err } + return nil } -func DeleteTable(table string) { - query := "drop table " + table - _, err := db.Query(query) +func (t *TableInfo) DeleteTable() error { + query := "drop table " + t.Name + _, err := t.db.conn.Query(query) if err != nil { - log.Println(err) + log.Println(err.Error()) + return err } - query = "drop sequence if exists sq_" + table - _, err = db.Query(query) + query = "drop sequence if exists sq_" + t.Name + _, err = t.db.conn.Query(query) if err != nil { - log.Println(err) + log.Println(err.Error()) + return err } + return nil } -func AddColumn(table string, name string, sqltype string) { - query := "alter table " + table + " add " + name + " " + sqltype - rows, err := db.Query(query) +func (t *TableInfo) AddColumn(name string, sqltype string) error { + query := "alter table " + t.Name + " add " + name + " " + sqltype + rows, err := t.db.conn.Query(query) if err != nil { log.Println(err) + return err } defer rows.Close() + return nil } -func DeleteColumn(table string, name string) { - query := "alter table " + table + " drop " + name - rows, err := db.Query(query) +func (t *TableInfo) DeleteColumn(name string) error { + query := "alter table " + t.Name + " drop " + name + rows, err := t.db.conn.Query(query) if err != nil { log.Println(err) + return err } defer rows.Close() + return nil } -func ListSequences() []AssRow { - return QueryAssociativeArray("SELECT sequence_name :: varchar FROM information_schema.sequences WHERE sequence_schema = 'public' ORDER BY sequence_name;") +func (db *Db) ListSequences() (Rows, error) { + return db.QueryAssociativeArray("SELECT sequence_name :: varchar FROM information_schema.sequences WHERE sequence_schema = 'public' ORDER BY sequence_name;") } -func buildSelect(table string, key string, columns []string, restriction string, sortkeys []string, dir ...string) string { +func (t *TableInfo) buildSelect(key string, columns []string, restriction string, sortkeys []string, dir ...string) string { if key != "" { columns = append(columns, key) } - query := "select " + strings.Join(columns, ",") + " from " + table + query := "select " + strings.Join(columns, ",") + " from " + t.Name if restriction != "" { query += " where " + restriction } @@ -183,17 +215,21 @@ func buildSelect(table string, key string, columns []string, restriction string, return query } -func Insert(table string, record AssRow) int { +func (t *TableInfo) Insert(record AssRow) (int, error) { columns := "" values := "" - schema := GetSchema(table) + t, err := t.GetSchema() + if err != nil { + log.Println(err) + return -1, err + } var id int for key, element := range record { - if strings.Contains(schema.Columns[key], "char") || strings.Contains(schema.Columns[key], "date") { + if strings.Contains(t.Columns[key], "char") || strings.Contains(t.Columns[key], "date") { columns += key + "," - values += fmt.Sprintf(pq.QuoteLiteral(fmt.Sprintf("%v", element))) + "," + values += fmt.Sprint(pq.QuoteLiteral(fmt.Sprintf("%v", element))) + "," } else { columns += key + "," @@ -201,19 +237,23 @@ func Insert(table string, record AssRow) int { } } - db.QueryRow("INSERT INTO " + table + "(" + removeLastChar(columns) + ") VALUES (" + removeLastChar(values) + ") RETURNING id").Scan(&id) - return id + t.db.conn.QueryRow("INSERT INTO " + t.Name + "(" + removeLastChar(columns) + ") VALUES (" + removeLastChar(values) + ") RETURNING id").Scan(&id) + return id, nil } -func Update(table string, record AssRow) string { +func (t *TableInfo) Update(record AssRow) error { - schema := GetSchema(table) + t, err := t.GetSchema() + if err != nil { + log.Println(err) + return err + } id := "" stack := "" for key, element := range record { - if strings.Contains(schema.Columns[key], "char") || strings.Contains(schema.Columns[key], "date") { + if strings.Contains(t.Columns[key], "char") || strings.Contains(t.Columns[key], "date") { stack = stack + " " + key + " = " + pq.QuoteLiteral(fmt.Sprintf("%v", element)) + "," @@ -227,17 +267,18 @@ func Update(table string, record AssRow) string { } } stack = removeLastChar(stack) - query := ("UPDATE " + table + " SET " + stack + " WHERE id = " + id) - rows, err := db.Query(query) + query := ("UPDATE " + t.Name + " SET " + stack + " WHERE id = " + id) + rows, err := t.db.conn.Query(query) if err != nil { log.Println(query) log.Println(err) + return err } defer rows.Close() - return query + return nil } -func Delete(table string, record AssRow) string { +func (t *TableInfo) Delete(record AssRow) error { id := "" values := "" @@ -248,31 +289,32 @@ func Delete(table string, record AssRow) string { } } - query := ("DELETE FROM " + table + " WHERE id = " + id) - rows, err := db.Query(query) + query := ("DELETE FROM " + t.Name + " WHERE id = " + id) + rows, err := t.db.conn.Query(query) if err != nil { log.Println(query) log.Println(err) + return err } defer rows.Close() - return query + return nil } -func UpdateOrInsert(table string, record AssRow) int { - id := 0 - +func (t *TableInfo) UpdateOrInsert(record AssRow) (int, error) { + id := -1 for key, element := range record { if key == "id" { sid := fmt.Sprintf("%v", element) id, _ = strconv.Atoi(sid) } } - if id == 0 { - return Insert(table, record) + if id == -1 { + return t.Insert(record) } else { - Update(table, record) - return id + t.Update(record) + return id, nil } + } func removeLastChar(s string) string { diff --git a/pg_test.go b/pg_test.go index 4112b6f..32475c4 100755 --- a/pg_test.go +++ b/pg_test.go @@ -9,8 +9,8 @@ import ( ) func TestCreateTable(t *testing.T) { - Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") - defer Close() + db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") + defer db.Close() jsonFile, err := os.Open("test_table.json") if err != nil { @@ -20,25 +20,37 @@ func TestCreateTable(t *testing.T) { byteValue, _ := ioutil.ReadAll(jsonFile) - var data Table - json.Unmarshal([]byte(byteValue), &data) + var jsonSource TableInfo + json.Unmarshal([]byte(byteValue), &jsonSource) - CreateTable(data) + err = db.CreateTable(jsonSource) + if err != nil { + fmt.Println(err.Error()) + } - tbl := GetSchema(data.Name) - if len(tbl.Columns) == 0 { + sch, err := db.Table("test").GetSchema() + if err != nil { + fmt.Println(err.Error()) + } + if len(sch.Columns) == 0 { t.Errorf("Create table failed") } } func TestAddColumn(t *testing.T) { - Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") - defer Close() + db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") + defer db.Close() - old := GetSchema("test") - AddColumn("test", "addcolumn", "integer") - new := GetSchema("test") + old, err := db.Table("test").GetSchema() + if err != nil { + fmt.Println(err.Error()) + } + db.Table("test").AddColumn("addcolumn", "integer") + new, err := db.Table("test").GetSchema() + if err != nil { + fmt.Println(err.Error()) + } if len(old.Columns) == len(new.Columns) { t.Errorf("Column already exist") @@ -47,20 +59,26 @@ func TestAddColumn(t *testing.T) { func TestInsert(t *testing.T) { - Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") - defer Close() + db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") + defer db.Close() vl := make(AssRow) vl["name"] = "toto" vl["description"] = "tata" - old := GetAssociativeArray("test", []string{"*"}, "", []string{}, "") + old, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "") + if err != nil { + fmt.Println(err.Error()) + } jsonStringOld, _ := json.Marshal(old) fmt.Println(string(jsonStringOld)) - UpdateOrInsert("test", vl) + db.Table("test").UpdateOrInsert(vl) - new := GetAssociativeArray("test", []string{"*"}, "", []string{}, "") + new, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "") + if err != nil { + fmt.Println(err.Error()) + } jsonStringNew, _ := json.Marshal(new) fmt.Println(string(jsonStringNew)) @@ -71,21 +89,27 @@ func TestInsert(t *testing.T) { func TestUpdate(t *testing.T) { - Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") - defer Close() + db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") + defer db.Close() vl := make(AssRow) vl["id"] = 1 vl["name"] = "titi" vl["description"] = "toto" - old := GetAssociativeArray("test", []string{"*"}, "", []string{}, "") + old, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "") + if err != nil { + fmt.Println(err.Error()) + } jsonStringOld, _ := json.Marshal(old) fmt.Println(string(jsonStringOld)) - UpdateOrInsert("test", vl) + db.Table("test").UpdateOrInsert(vl) - new := GetAssociativeArray("test", []string{"*"}, "", []string{}, "") + new, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "") + if err != nil { + fmt.Println(err.Error()) + } jsonStringNew, _ := json.Marshal(new) fmt.Println(string(jsonStringNew)) @@ -97,19 +121,25 @@ func TestUpdate(t *testing.T) { func TestDelete(t *testing.T) { - Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") - defer Close() + db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") + defer db.Close() vl := make(AssRow) vl["id"] = 1 - old := GetAssociativeArray("test", []string{"*"}, "", []string{}, "") + old, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "") + if err != nil { + fmt.Println(err.Error()) + } jsonStringOld, _ := json.Marshal(old) fmt.Println(string(jsonStringOld)) - Delete("test", vl) + db.Table("test").Delete(vl) - new := GetAssociativeArray("test", []string{"*"}, "", []string{}, "") + new, err := db.Table("test").GetAssociativeArray([]string{"*"}, "", []string{}, "") + if err != nil { + fmt.Println(err.Error()) + } jsonStringNew, _ := json.Marshal(new) fmt.Println(string(jsonStringNew)) @@ -120,12 +150,18 @@ func TestDelete(t *testing.T) { func TestDeleteColumn(t *testing.T) { - Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") - defer Close() + db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") + defer db.Close() - old := GetSchema("test") - DeleteColumn("test", "addcolumn") - new := GetSchema("test") + old, err := db.Table("test").GetSchema() + if err != nil { + fmt.Println(err.Error()) + } + db.Table("test").DeleteColumn("addcolumn") + new, err := db.Table("test").GetSchema() + if err != nil { + fmt.Println(err.Error()) + } if len(old.Columns) == len(new.Columns) { t.Errorf("Error column not deleted") @@ -133,13 +169,15 @@ func TestDeleteColumn(t *testing.T) { } func TestDeleteTable(t *testing.T) { - Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") - defer Close() + db := Open("postgres", "host=127.0.0.1 port=5432 user=test password=test dbname=test sslmode=disable") + defer db.Close() - DeleteTable("test") - - tbl := GetSchema("test") + db.Table("test").DeleteTable() + tbl, err := db.Table("test").GetSchema() + if err != nil { + fmt.Println(err.Error()) + } if len(tbl.Columns) != 0 { t.Errorf("Delete table failed") }