commit
6a95ef0007
15
.editorconfig
Normal file
15
.editorconfig
Normal file
@ -0,0 +1,15 @@
|
||||
root = true
|
||||
|
||||
[*]
|
||||
indent_style = tab
|
||||
indent_size = 4
|
||||
end_of_line = lf
|
||||
charset = utf-8
|
||||
trim_trailing_whitespace = true
|
||||
insert_final_newline = true
|
||||
|
||||
[*.py]
|
||||
max_line_length = 99
|
||||
|
||||
[*.{yaml,yml,py}]
|
||||
indent_style = space
|
20
.gitignore
vendored
20
.gitignore
vendored
@ -1,19 +1,15 @@
|
||||
.idea
|
||||
.vscode
|
||||
build/
|
||||
dist/
|
||||
*.egg-info
|
||||
|
||||
.venv
|
||||
pip-selfcheck.json
|
||||
*.pyc
|
||||
__pycache__
|
||||
|
||||
*.db
|
||||
*.yaml
|
||||
!example-config.yaml
|
||||
|
||||
logs/
|
||||
|
||||
plugins/
|
||||
|
||||
# Bots under maubot.xyz
|
||||
jesaribot/
|
||||
sed/
|
||||
github/
|
||||
gitlab/
|
||||
rss/
|
||||
factorial/
|
||||
dictionary/
|
||||
|
32
Dockerfile
32
Dockerfile
@ -1,20 +1,20 @@
|
||||
FROM golang:1-alpine AS builder
|
||||
FROM docker.io/alpine:3.8
|
||||
|
||||
RUN apk add --no-cache git ca-certificates gcc musl-dev
|
||||
RUN wget -qO /usr/local/bin/dep https://github.com/golang/dep/releases/download/v0.5.0/dep-linux-amd64
|
||||
RUN chmod +x /usr/local/bin/dep
|
||||
ENV UID=1338 \
|
||||
GID=1338
|
||||
|
||||
COPY Gopkg.lock Gopkg.toml /go/src/maubot.xyz/
|
||||
WORKDIR /go/src/maubot.xyz/
|
||||
RUN dep ensure -vendor-only
|
||||
COPY . /opt/maubot
|
||||
WORKDIR /opt/maubot
|
||||
RUN apk add --no-cache \
|
||||
python3-dev \
|
||||
build-base \
|
||||
py3-aiohttp \
|
||||
py3-sqlalchemy \
|
||||
py3-attrs \
|
||||
ca-certificates \
|
||||
su-exec \
|
||||
&& pip3 install -r requirements.txt -r optional-requirements.txt
|
||||
|
||||
COPY . /go/src/maubot.xyz/
|
||||
RUN go build -o /usr/bin/maubot maubot.xyz/cmd/maubot
|
||||
VOLUME /data
|
||||
|
||||
|
||||
FROM alpine
|
||||
|
||||
RUN apk add --no-cache ca-certificates
|
||||
COPY --from=builder /usr/bin/maubot /usr/bin/maubot
|
||||
|
||||
CMD ["/usr/bin/maubot", "-c", "/etc/maubot/config.yaml"]
|
||||
CMD ["/opt/mautrix-telegram/docker-run.sh"]
|
||||
|
75
Gopkg.lock
generated
75
Gopkg.lock
generated
@ -1,75 +0,0 @@
|
||||
# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
|
||||
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/gorilla/context"
|
||||
packages = ["."]
|
||||
revision = "08b5f424b9271eedf6f9f0ce86cb9396ed337a42"
|
||||
version = "v1.1.1"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/gorilla/mux"
|
||||
packages = ["."]
|
||||
revision = "e3702bed27f0d39777b0b37b664b6280e8ef8fbf"
|
||||
version = "v1.6.2"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/mattn/go-sqlite3"
|
||||
packages = ["."]
|
||||
revision = "25ecb14adfc7543176f7d85291ec7dba82c6f7e4"
|
||||
version = "v1.9.0"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "github.com/shurcooL/sanitized_anchor_name"
|
||||
packages = ["."]
|
||||
revision = "86672fcb3f950f35f2e675df2240550f2a50762f"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/net"
|
||||
packages = [
|
||||
"html",
|
||||
"html/atom"
|
||||
]
|
||||
revision = "26e67e76b6c3f6ce91f7c52def5af501b4e0f3a2"
|
||||
|
||||
[[projects]]
|
||||
name = "gopkg.in/russross/blackfriday.v2"
|
||||
packages = ["."]
|
||||
revision = "cadec560ec52d93835bf2f15bd794700d3a2473b"
|
||||
version = "v2.0.0"
|
||||
|
||||
[[projects]]
|
||||
name = "gopkg.in/yaml.v2"
|
||||
packages = ["."]
|
||||
revision = "5420a8b6744d3b0345ab293f6fcba19c978f1183"
|
||||
version = "v2.2.1"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "maunium.net/go/gomatrix"
|
||||
packages = [
|
||||
".",
|
||||
"format"
|
||||
]
|
||||
revision = "920b154a410aeb5a55200d7b21363732abff3502"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "maunium.net/go/mauflag"
|
||||
packages = ["."]
|
||||
revision = "8337821952ba5e919673bd62c502d43474e5e71d"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "maunium.net/go/maulogger"
|
||||
packages = ["."]
|
||||
revision = "ed98745dedb5f9296c1b2a0ed9424d7347d7e7d4"
|
||||
|
||||
[solve-meta]
|
||||
analyzer-name = "dep"
|
||||
analyzer-version = 1
|
||||
inputs-digest = "6b56fff780b66591381a1d1c4572951bbad3deea30e0774d34721d89adeee379"
|
||||
solver-name = "gps-cdcl"
|
||||
solver-version = 1
|
58
Gopkg.toml
58
Gopkg.toml
@ -1,58 +0,0 @@
|
||||
# Gopkg.toml example
|
||||
#
|
||||
# Refer to https://golang.github.io/dep/docs/Gopkg.toml.html
|
||||
# for detailed Gopkg.toml documentation.
|
||||
#
|
||||
# required = ["github.com/user/thing/cmd/thing"]
|
||||
# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"]
|
||||
#
|
||||
# [[constraint]]
|
||||
# name = "github.com/user/project"
|
||||
# version = "1.0.0"
|
||||
#
|
||||
# [[constraint]]
|
||||
# name = "github.com/user/project2"
|
||||
# branch = "dev"
|
||||
# source = "github.com/myfork/project2"
|
||||
#
|
||||
# [[override]]
|
||||
# name = "github.com/x/y"
|
||||
# version = "2.4.0"
|
||||
#
|
||||
# [prune]
|
||||
# non-go = false
|
||||
# go-tests = true
|
||||
# unused-packages = true
|
||||
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/gorilla/mux"
|
||||
version = "1.6.2"
|
||||
|
||||
[[constraint]]
|
||||
name = "github.com/mattn/go-sqlite3"
|
||||
version = "1.9.0"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/net"
|
||||
|
||||
[[constraint]]
|
||||
name = "gopkg.in/yaml.v2"
|
||||
version = "2.2.1"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "maunium.net/go/gomatrix"
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "maunium.net/go/mauflag"
|
||||
|
||||
[[constraint]]
|
||||
name = "maunium.net/go/maulogger"
|
||||
branch = "master"
|
||||
|
||||
[prune]
|
||||
go-tests = true
|
||||
unused-packages = true
|
@ -1,5 +1,5 @@
|
||||
# maubot
|
||||
A plugin-based [Matrix](https://matrix.org) bot system written in Go.
|
||||
A plugin-based [Matrix](https://matrix.org) bot system written in Python.
|
||||
|
||||
Work in progress. Please come back later.
|
||||
|
||||
@ -9,3 +9,5 @@ Matrix room: [#maubot:maunium.net](https://matrix.to/#/#maubot:maunium.net)
|
||||
## Plugins
|
||||
* [jesaribot](https://github.com/maubot/jesaribot) - A simple bot that replies with an image when you say "jesari".
|
||||
* [sed](https://github.com/maubot/sed) - A bot to do sed-like replacements.
|
||||
* [factorial](https://github.com/maubot/factorial) - A bot to calculate unexpected factorials.
|
||||
* [dictionary](https://github.com/maubot/dictionary) - A bot that provides dictionary definitions for words.
|
||||
|
77
app/bot.go
77
app/bot.go
@ -1,77 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"maubot.xyz"
|
||||
"maubot.xyz/config"
|
||||
"maubot.xyz/database"
|
||||
"maubot.xyz/matrix"
|
||||
log "maunium.net/go/maulogger"
|
||||
)
|
||||
|
||||
type Bot struct {
|
||||
Config *config.MainConfig
|
||||
Database *database.Database
|
||||
Clients map[string]*matrix.Client
|
||||
PluginCreators map[string]*maubot.PluginCreator
|
||||
Plugins map[string]*PluginWrapper
|
||||
Server *http.Server
|
||||
}
|
||||
|
||||
func New(config *config.MainConfig) *Bot {
|
||||
return &Bot{
|
||||
Config: config,
|
||||
Clients: make(map[string]*matrix.Client),
|
||||
Plugins: make(map[string]*PluginWrapper),
|
||||
PluginCreators: make(map[string]*maubot.PluginCreator),
|
||||
}
|
||||
}
|
||||
|
||||
func (bot *Bot) Init() {
|
||||
bot.initDatabase()
|
||||
bot.initClients()
|
||||
bot.initServer()
|
||||
bot.loadPlugins()
|
||||
bot.createPlugins()
|
||||
}
|
||||
|
||||
func (bot *Bot) Start() {
|
||||
go bot.startClients()
|
||||
go bot.startServer()
|
||||
bot.startPlugins()
|
||||
}
|
||||
|
||||
func (bot *Bot) Stop() {
|
||||
bot.stopPlugins()
|
||||
bot.stopServer()
|
||||
bot.stopClients()
|
||||
}
|
||||
|
||||
func (bot *Bot) initDatabase() {
|
||||
log.Debugln("Initializing database")
|
||||
bot.Database = &bot.Config.Database
|
||||
err := bot.Database.Connect()
|
||||
if err != nil {
|
||||
log.Fatalln("Failed to connect to database:", err)
|
||||
os.Exit(2)
|
||||
}
|
||||
bot.Database.CreateTables()
|
||||
}
|
59
app/http.go
59
app/http.go
@ -1,59 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "maunium.net/go/maulogger"
|
||||
)
|
||||
|
||||
func (bot *Bot) initServer() {
|
||||
log.Debugln("Initializing HTTP server")
|
||||
r := mux.NewRouter()
|
||||
http.Handle(bot.Config.Server.BasePath, r)
|
||||
bot.Server = &http.Server{
|
||||
Addr: bot.Config.Server.Listen,
|
||||
WriteTimeout: time.Second * 15,
|
||||
ReadTimeout: time.Second * 15,
|
||||
IdleTimeout: time.Second * 60,
|
||||
Handler: r,
|
||||
}
|
||||
}
|
||||
|
||||
func (bot *Bot) startServer() {
|
||||
log.Debugf("Listening at http://%s%s\n", bot.Server.Addr, bot.Config.Server.BasePath)
|
||||
if err := bot.Server.ListenAndServe(); err != nil {
|
||||
log.Fatalln("HTTP server errored:", err)
|
||||
bot.Server = nil
|
||||
bot.Stop()
|
||||
os.Exit(10)
|
||||
}
|
||||
}
|
||||
|
||||
func (bot *Bot) stopServer() {
|
||||
if bot.Server != nil {
|
||||
log.Debugln("Stopping HTTP server")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
bot.Server.Shutdown(ctx)
|
||||
}
|
||||
}
|
@ -1,58 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"maubot.xyz/matrix"
|
||||
log "maunium.net/go/maulogger"
|
||||
)
|
||||
|
||||
func (bot *Bot) initClients() {
|
||||
log.Debugln("Initializing Matrix clients")
|
||||
clients := bot.Database.MatrixClient.GetAll()
|
||||
for _, client := range clients {
|
||||
mxClient, err := matrix.NewClient(client)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to create client to %s as %s: %v\n", client.Homeserver, client.UserID, err)
|
||||
os.Exit(3)
|
||||
}
|
||||
log.Debugln("Initialized user", client.UserID, "with homeserver", client.Homeserver)
|
||||
bot.Clients[client.UserID] = mxClient
|
||||
}
|
||||
}
|
||||
|
||||
func (bot *Bot) startClients() {
|
||||
log.Debugln("Starting Matrix syncer")
|
||||
for _, client := range bot.Clients {
|
||||
go func() {
|
||||
client.SetAvatarURL(client.DB.AvatarURL)
|
||||
client.SetDisplayName(client.DB.DisplayName)
|
||||
}()
|
||||
if client.DB.Sync {
|
||||
client.Sync()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (bot *Bot) stopClients() {
|
||||
log.Debugln("Stopping Matrix syncers")
|
||||
for _, client := range bot.Clients {
|
||||
client.StopSync()
|
||||
}
|
||||
}
|
@ -1,81 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"plugin"
|
||||
|
||||
"maubot.xyz"
|
||||
"maubot.xyz/database"
|
||||
)
|
||||
|
||||
type PluginWrapper struct {
|
||||
maubot.Plugin
|
||||
Creator *maubot.PluginCreator
|
||||
DB *database.Plugin
|
||||
}
|
||||
|
||||
func LoadPlugin(path string) (*maubot.PluginCreator, error) {
|
||||
rawPlugin, err := plugin.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open: %v", err)
|
||||
}
|
||||
|
||||
pluginCreatorSymbol, err := rawPlugin.Lookup("Plugin")
|
||||
if err == nil {
|
||||
pluginCreator, ok := pluginCreatorSymbol.(*maubot.PluginCreator)
|
||||
if ok {
|
||||
pluginCreator.Path = path
|
||||
return pluginCreator, nil
|
||||
}
|
||||
}
|
||||
|
||||
pluginCreatorFuncSymbol, err := rawPlugin.Lookup("Create")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("symbol \"Create\" not found: %v", err)
|
||||
}
|
||||
pluginCreatorFunc, ok := pluginCreatorFuncSymbol.(maubot.PluginCreatorFunc)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("symbol \"Create\" does not implement maubot.PluginCreator")
|
||||
}
|
||||
|
||||
nameSymbol, err := rawPlugin.Lookup("Name")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("symbol \"Name\" not found: %v", err)
|
||||
}
|
||||
name, ok := nameSymbol.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("symbol \"Name\" is not a string")
|
||||
}
|
||||
|
||||
versionSymbol, err := rawPlugin.Lookup("Version")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("symbol \"Version\" not found: %v", err)
|
||||
}
|
||||
version, ok := versionSymbol.(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("symbol \"Version\" is not a string")
|
||||
}
|
||||
|
||||
return &maubot.PluginCreator{
|
||||
Create: pluginCreatorFunc,
|
||||
Name: name,
|
||||
Version: version,
|
||||
Path: path,
|
||||
}, nil
|
||||
}
|
112
app/plugins.go
112
app/plugins.go
@ -1,112 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package app
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
log "maunium.net/go/maulogger"
|
||||
)
|
||||
|
||||
func (bot *Bot) loadPlugin(dir, fileName string) {
|
||||
ext := fileName[len(fileName)-4:]
|
||||
if ext != ".mbp" {
|
||||
return
|
||||
}
|
||||
|
||||
path := filepath.Join(dir, fileName)
|
||||
|
||||
pluginCreator, err := LoadPlugin(path)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load plugin at %s: %v\n", path, err)
|
||||
os.Exit(4)
|
||||
}
|
||||
|
||||
_, exists := bot.PluginCreators[pluginCreator.Name]
|
||||
if exists {
|
||||
log.Debugf("Skipping plugin at %s: plugin with same name already loaded", path)
|
||||
return
|
||||
}
|
||||
|
||||
bot.PluginCreators[pluginCreator.Name] = pluginCreator
|
||||
log.Debugf("Loaded plugin creator %s v%s\n", pluginCreator.Name, pluginCreator.Version)
|
||||
}
|
||||
|
||||
func (bot *Bot) loadPlugins() {
|
||||
for _, dir := range bot.Config.PluginDirs {
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to read plugin directory %s: %v\n", dir, err)
|
||||
os.Exit(4)
|
||||
}
|
||||
for _, file := range files {
|
||||
bot.loadPlugin(dir, file.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (bot *Bot) createPlugins() {
|
||||
log.Debugln("Creating plugin instances")
|
||||
plugins := bot.Database.Plugin.GetAll()
|
||||
for _, plugin := range plugins {
|
||||
if !plugin.Enabled {
|
||||
log.Debugln("Skipping disabled plugin", plugin.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
creator, ok := bot.PluginCreators[plugin.Type]
|
||||
if !ok {
|
||||
log.Errorln("Plugin creator", plugin.Type, "for", plugin.ID, "not found, disabling plugin...")
|
||||
plugin.Enabled = false
|
||||
plugin.Update()
|
||||
continue
|
||||
}
|
||||
|
||||
client, ok := bot.Clients[plugin.UserID]
|
||||
if !ok {
|
||||
log.Errorln("Client", plugin.UserID, "for", plugin.ID, "not found, disabling plugin...")
|
||||
plugin.Enabled = false
|
||||
plugin.Update()
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("Created plugin %s (type %s v%s)\n", plugin.ID, creator.Name, creator.Version)
|
||||
bot.Plugins[plugin.ID] = &PluginWrapper{
|
||||
Plugin: creator.Create(client.Proxy(plugin.ID), log.Sub(plugin.ID)),
|
||||
Creator: creator,
|
||||
DB: plugin,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (bot *Bot) startPlugins() {
|
||||
log.Debugln("Starting plugin instances...")
|
||||
for _, plugin := range bot.Plugins {
|
||||
log.Debugf("Starting plugin %s (type %s v%s)\n", plugin.DB.ID, plugin.Creator.Name, plugin.Creator.Version)
|
||||
go plugin.Start()
|
||||
}
|
||||
}
|
||||
|
||||
func (bot *Bot) stopPlugins() {
|
||||
log.Debugln("Stopping plugin instances...")
|
||||
for _, plugin := range bot.Plugins {
|
||||
log.Debugf("Stopping plugin %s (type %s v%s)\n", plugin.DB.ID, plugin.Creator.Name, plugin.Creator.Version)
|
||||
plugin.Stop()
|
||||
}
|
||||
}
|
@ -1,70 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"maubot.xyz/app"
|
||||
"maubot.xyz/config"
|
||||
flag "maunium.net/go/mauflag"
|
||||
log "maunium.net/go/maulogger"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.SetHelpTitles("maubot - A plugin-based Matrix bot system written in Go.", "maubot [-c /path/to/config] [-h]")
|
||||
configPath := flag.MakeFull("c", "config", "The path to the main config file", "maubot.yaml").String()
|
||||
wantHelp, _ := flag.MakeHelpFlag()
|
||||
|
||||
err := flag.Parse()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
flag.PrintHelp()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if *wantHelp {
|
||||
flag.PrintHelp()
|
||||
return
|
||||
}
|
||||
|
||||
cfg := &config.MainConfig{}
|
||||
err = cfg.Load(*configPath)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to load config:", err)
|
||||
return
|
||||
}
|
||||
cfg.Logging.Configure(log.DefaultLogger)
|
||||
log.OpenFile()
|
||||
log.Debugln("Logger configured")
|
||||
|
||||
bot := app.New(cfg)
|
||||
bot.Init()
|
||||
bot.Start()
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
<-c
|
||||
log.Debugln("Interrupt received, stopping components...")
|
||||
bot.Stop()
|
||||
log.Debugln("Components stopped, bye!")
|
||||
os.Exit(0)
|
||||
}
|
168
commands.go
168
commands.go
@ -1,168 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package maubot
|
||||
|
||||
type CommandHandler func(*Event) CommandHandlerResult
|
||||
|
||||
type CommandSpec struct {
|
||||
Commands []Command `json:"commands,omitempty"`
|
||||
PassiveCommands []PassiveCommand `json:"passive_commands,omitempty"`
|
||||
}
|
||||
|
||||
func (spec *CommandSpec) Clone() *CommandSpec {
|
||||
return &CommandSpec{
|
||||
Commands: append([]Command(nil), spec.Commands...),
|
||||
PassiveCommands: append([]PassiveCommand(nil), spec.PassiveCommands...),
|
||||
}
|
||||
}
|
||||
|
||||
func (spec *CommandSpec) Merge(otherSpecs ...*CommandSpec) {
|
||||
for _, otherSpec := range otherSpecs {
|
||||
spec.Commands = append(spec.Commands, otherSpec.Commands...)
|
||||
spec.PassiveCommands = append(spec.PassiveCommands, otherSpec.PassiveCommands...)
|
||||
}
|
||||
}
|
||||
|
||||
func (spec *CommandSpec) Equals(otherSpec *CommandSpec) bool {
|
||||
if otherSpec == nil ||
|
||||
len(spec.Commands) != len(otherSpec.Commands) ||
|
||||
len(spec.PassiveCommands) != len(otherSpec.PassiveCommands) {
|
||||
return false
|
||||
}
|
||||
|
||||
for index, cmd := range spec.Commands {
|
||||
otherCmd := otherSpec.Commands[index]
|
||||
if !cmd.Equals(otherCmd) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
for index, cmd := range spec.PassiveCommands {
|
||||
otherCmd := otherSpec.PassiveCommands[index]
|
||||
if !cmd.Equals(otherCmd) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
type Command struct {
|
||||
Syntax string `json:"syntax"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Arguments ArgumentMap `json:"arguments"`
|
||||
}
|
||||
|
||||
func (cmd Command) Equals(otherCmd Command) bool {
|
||||
return cmd.Syntax == otherCmd.Syntax &&
|
||||
cmd.Description == otherCmd.Description &&
|
||||
cmd.Arguments.Equals(otherCmd.Arguments)
|
||||
}
|
||||
|
||||
type ArgumentMap map[string]Argument
|
||||
|
||||
func (argMap ArgumentMap) Equals(otherMap ArgumentMap) bool {
|
||||
if len(argMap) != len(otherMap) {
|
||||
return false
|
||||
}
|
||||
|
||||
for name, argument := range argMap {
|
||||
otherArgument, ok := otherMap[name]
|
||||
if !ok || !argument.Equals(otherArgument) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type Argument struct {
|
||||
Matches string `json:"matches"`
|
||||
Required bool `json:"required"`
|
||||
Description string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
func (arg Argument) Equals(otherArg Argument) bool {
|
||||
return arg.Matches == otherArg.Matches &&
|
||||
arg.Required == otherArg.Required &&
|
||||
arg.Description == otherArg.Description
|
||||
}
|
||||
|
||||
// Common PassiveCommand MatchAgainst targets.
|
||||
const (
|
||||
MatchAgainstBody = "body"
|
||||
)
|
||||
|
||||
// JSONLeftEquals checks if the given JSON-parsed interfaces are equal.
|
||||
// Extra properties in the right interface are ignored.
|
||||
func JSONLeftEquals(left, right interface{}) bool {
|
||||
switch val := left.(type) {
|
||||
case nil:
|
||||
return right == nil
|
||||
case bool:
|
||||
rightVal, ok := right.(bool)
|
||||
return ok && rightVal
|
||||
case float64:
|
||||
rightVal, ok := right.(float64)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return val == rightVal
|
||||
case string:
|
||||
rightVal, ok := right.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return val == rightVal
|
||||
case []interface{}:
|
||||
rightVal, ok := right.([]interface{})
|
||||
if !ok || len(val) != len(rightVal) {
|
||||
return false
|
||||
}
|
||||
for index, leftChild := range val {
|
||||
rightChild := rightVal[index]
|
||||
if !JSONLeftEquals(leftChild, rightChild) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
case map[string]interface{}:
|
||||
rightVal, ok := right.(map[string]interface{})
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for key, leftChild := range val {
|
||||
rightChild, ok := rightVal[key]
|
||||
if !ok || !JSONLeftEquals(leftChild, rightChild) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type PassiveCommand struct {
|
||||
Name string `json:"name"`
|
||||
Matches string `json:"matches"`
|
||||
MatchAgainst string `json:"match_against"`
|
||||
MatchEvent interface{} `json:"match_event"`
|
||||
}
|
||||
|
||||
func (cmd PassiveCommand) Equals(otherCmd PassiveCommand) bool {
|
||||
return cmd.Name == otherCmd.Name &&
|
||||
cmd.Matches == otherCmd.Matches &&
|
||||
cmd.MatchAgainst == otherCmd.MatchAgainst &&
|
||||
(cmd.MatchEvent != nil && JSONLeftEquals(cmd.MatchEvent, otherCmd.MatchEvent) || otherCmd.MatchEvent == nil)
|
||||
}
|
@ -1,54 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
"maubot.xyz/database"
|
||||
)
|
||||
|
||||
type MainConfig struct {
|
||||
Logging LogConfig `yaml:"logging"`
|
||||
Database database.Database `yaml:"database"`
|
||||
Server ServerConfig `yaml:"server"`
|
||||
PluginDirs []string `yaml:"plugin_directories"`
|
||||
}
|
||||
|
||||
func (config *MainConfig) Load(path string) error {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return yaml.Unmarshal(data, config)
|
||||
}
|
||||
|
||||
func (config *MainConfig) Save(path string) error {
|
||||
data, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ioutil.WriteFile(path, data, 0644)
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Listen string `yaml:"listen"`
|
||||
BasePath string `yaml:"base_path"`
|
||||
}
|
@ -1,122 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"maunium.net/go/maulogger"
|
||||
)
|
||||
|
||||
// LogConfig contains configs for the logger.
|
||||
type LogConfig struct {
|
||||
Directory string `yaml:"directory"`
|
||||
FileNameFormat string `yaml:"file_name_format"`
|
||||
FileDateFormat string `yaml:"file_date_format"`
|
||||
FileMode uint32 `yaml:"file_mode"`
|
||||
TimestampFormat string `yaml:"timestamp_format"`
|
||||
RawPrintLevel string `yaml:"print_level"`
|
||||
PrintLevel int `yaml:"-"`
|
||||
}
|
||||
|
||||
type umLogConfig LogConfig
|
||||
|
||||
func (lc *LogConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
err := unmarshal((*umLogConfig)(lc))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch strings.ToUpper(lc.RawPrintLevel) {
|
||||
case "DEBUG":
|
||||
lc.PrintLevel = maulogger.LevelDebug.Severity
|
||||
case "INFO":
|
||||
lc.PrintLevel = maulogger.LevelInfo.Severity
|
||||
case "WARN", "WARNING":
|
||||
lc.PrintLevel = maulogger.LevelWarn.Severity
|
||||
case "ERR", "ERROR":
|
||||
lc.PrintLevel = maulogger.LevelError.Severity
|
||||
case "FATAL":
|
||||
lc.PrintLevel = maulogger.LevelFatal.Severity
|
||||
default:
|
||||
return errors.New("invalid print level " + lc.RawPrintLevel)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (lc *LogConfig) MarshalYAML() (interface{}, error) {
|
||||
switch {
|
||||
case lc.PrintLevel >= maulogger.LevelFatal.Severity:
|
||||
lc.RawPrintLevel = maulogger.LevelFatal.Name
|
||||
case lc.PrintLevel >= maulogger.LevelError.Severity:
|
||||
lc.RawPrintLevel = maulogger.LevelError.Name
|
||||
case lc.PrintLevel >= maulogger.LevelWarn.Severity:
|
||||
lc.RawPrintLevel = maulogger.LevelWarn.Name
|
||||
case lc.PrintLevel >= maulogger.LevelInfo.Severity:
|
||||
lc.RawPrintLevel = maulogger.LevelInfo.Name
|
||||
default:
|
||||
lc.RawPrintLevel = maulogger.LevelDebug.Name
|
||||
}
|
||||
return lc, nil
|
||||
}
|
||||
|
||||
// CreateLogConfig creates a basic LogConfig.
|
||||
func CreateLogConfig() LogConfig {
|
||||
return LogConfig{
|
||||
Directory: "./logs",
|
||||
FileNameFormat: "{{.Date}}-{{.Index}}.log",
|
||||
TimestampFormat: "Jan _2, 2006 15:04:05",
|
||||
FileMode: 0600,
|
||||
FileDateFormat: "2006-01-02",
|
||||
PrintLevel: 10,
|
||||
}
|
||||
}
|
||||
|
||||
type FileFormatData struct {
|
||||
Date string
|
||||
Index int
|
||||
}
|
||||
|
||||
// GetFileFormat returns a mauLogger-compatible logger file format based on the data in the struct.
|
||||
func (lc LogConfig) GetFileFormat() maulogger.LoggerFileFormat {
|
||||
os.MkdirAll(lc.Directory, 0700)
|
||||
path := filepath.Join(lc.Directory, lc.FileNameFormat)
|
||||
tpl, _ := template.New("fileformat").Parse(path)
|
||||
|
||||
return func(now string, i int) string {
|
||||
var buf strings.Builder
|
||||
tpl.Execute(&buf, FileFormatData{
|
||||
Date: now,
|
||||
Index: i,
|
||||
})
|
||||
return buf.String()
|
||||
}
|
||||
}
|
||||
|
||||
// Configure configures a mauLogger instance with the data in this struct.
|
||||
func (lc LogConfig) Configure(log maulogger.Logger) {
|
||||
basicLogger := log.(*maulogger.BasicLogger)
|
||||
basicLogger.FileFormat = lc.GetFileFormat()
|
||||
basicLogger.FileMode = os.FileMode(lc.FileMode)
|
||||
basicLogger.FileTimeFormat = lc.FileDateFormat
|
||||
basicLogger.TimeFormat = lc.TimestampFormat
|
||||
basicLogger.PrintLevel = lc.PrintLevel
|
||||
}
|
@ -1,161 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"maubot.xyz"
|
||||
log "maunium.net/go/maulogger"
|
||||
"database/sql"
|
||||
"sort"
|
||||
)
|
||||
|
||||
type MatrixClient struct {
|
||||
db *Database
|
||||
sql *sql.DB
|
||||
|
||||
UserID string `json:"user_id"`
|
||||
Homeserver string `json:"homeserver"`
|
||||
AccessToken string `json:"access_token"`
|
||||
NextBatch string `json:"next_batch"`
|
||||
FilterID string `json:"filter_id"`
|
||||
|
||||
Sync bool `json:"sync"`
|
||||
AutoJoinRooms bool `json:"auto_join_rooms"`
|
||||
DisplayName string `json:"display_name"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
|
||||
CommandSpecs map[string]*CommandSpec `json:"command_specs"`
|
||||
}
|
||||
|
||||
type MatrixClientStatic struct {
|
||||
db *Database
|
||||
sql *sql.DB
|
||||
}
|
||||
|
||||
func (mcs *MatrixClientStatic) CreateTable() error {
|
||||
_, err := mcs.sql.Exec(`CREATE TABLE IF NOT EXISTS matrix_client (
|
||||
user_id VARCHAR(255) PRIMARY KEY,
|
||||
homeserver VARCHAR(255) NOT NULL,
|
||||
access_token VARCHAR(255) NOT NULL,
|
||||
next_batch VARCHAR(255) NOT NULL,
|
||||
filter_id VARCHAR(255) NOT NULL,
|
||||
|
||||
sync BOOLEAN NOT NULL,
|
||||
autojoin BOOLEAN NOT NULL,
|
||||
display_name VARCHAR(255) NOT NULL,
|
||||
avatar_url VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
return err
|
||||
}
|
||||
|
||||
func (mcs *MatrixClientStatic) Get(userID string) *MatrixClient {
|
||||
row := mcs.sql.QueryRow("SELECT user_id, homeserver, access_token, next_batch, filter_id, sync, autojoin, display_name, avatar_url FROM matrix_client WHERE user_id=?", userID)
|
||||
if row != nil {
|
||||
return mcs.New().Scan(row)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mcs *MatrixClientStatic) GetAll() (clients []*MatrixClient) {
|
||||
rows, err := mcs.sql.Query("SELECT user_id, homeserver, access_token, next_batch, filter_id, sync, autojoin, display_name, avatar_url FROM matrix_client")
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
clients = append(clients, mcs.New().Scan(rows))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (mcs *MatrixClientStatic) New() *MatrixClient {
|
||||
return &MatrixClient{
|
||||
db: mcs.db,
|
||||
sql: mcs.sql,
|
||||
}
|
||||
}
|
||||
|
||||
func (mxc *MatrixClient) Scan(row Scannable) *MatrixClient {
|
||||
err := row.Scan(&mxc.UserID, &mxc.Homeserver, &mxc.AccessToken, &mxc.NextBatch, &mxc.FilterID, &mxc.Sync, &mxc.AutoJoinRooms, &mxc.DisplayName, &mxc.AvatarURL)
|
||||
if err != nil {
|
||||
log.Errorln("MatrixClient scan failed:", err)
|
||||
return mxc
|
||||
}
|
||||
mxc.LoadCommandSpecs()
|
||||
return mxc
|
||||
}
|
||||
|
||||
func (mxc *MatrixClient) SetCommandSpec(owner string, newSpec *maubot.CommandSpec) bool {
|
||||
spec, ok := mxc.CommandSpecs[owner]
|
||||
if ok && newSpec.Equals(spec.CommandSpec) {
|
||||
return false
|
||||
}
|
||||
if spec == nil {
|
||||
spec = mxc.db.CommandSpec.New()
|
||||
spec.CommandSpec = newSpec
|
||||
spec.Insert()
|
||||
} else {
|
||||
spec.CommandSpec = newSpec
|
||||
spec.Update()
|
||||
}
|
||||
mxc.CommandSpecs[owner] = spec
|
||||
return true
|
||||
}
|
||||
|
||||
func (mxc *MatrixClient) LoadCommandSpecs() *MatrixClient {
|
||||
specs := mxc.db.CommandSpec.GetAllByClient(mxc.UserID)
|
||||
mxc.CommandSpecs = make(map[string]*CommandSpec)
|
||||
for _, spec := range specs {
|
||||
mxc.CommandSpecs[spec.Owner] = spec
|
||||
}
|
||||
return mxc
|
||||
}
|
||||
|
||||
func (mxc *MatrixClient) CommandSpecIDs() []string {
|
||||
keys := make([]string, len(mxc.CommandSpecs))
|
||||
i := 0
|
||||
for key := range mxc.CommandSpecs {
|
||||
keys[i] = key
|
||||
i++
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func (mxc *MatrixClient) Commands() *maubot.CommandSpec {
|
||||
if len(mxc.CommandSpecs) == 0 {
|
||||
return &maubot.CommandSpec{}
|
||||
}
|
||||
specIDs := mxc.CommandSpecIDs()
|
||||
spec := mxc.CommandSpecs[specIDs[0]].Clone()
|
||||
for _, specID := range specIDs[1:] {
|
||||
spec.Merge(mxc.CommandSpecs[specID].CommandSpec)
|
||||
}
|
||||
return spec
|
||||
}
|
||||
|
||||
func (mxc *MatrixClient) Insert() error {
|
||||
_, err := mxc.sql.Exec("INSERT INTO matrix_client (user_id, homeserver, access_token, next_batch, filter_id, sync, autojoin, display_name, avatar_url) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
mxc.UserID, mxc.Homeserver, mxc.AccessToken, mxc.NextBatch, mxc.FilterID, mxc.Sync, mxc.AutoJoinRooms, mxc.DisplayName, mxc.AvatarURL)
|
||||
return err
|
||||
}
|
||||
|
||||
func (mxc *MatrixClient) Update() error {
|
||||
_, err := mxc.sql.Exec("UPDATE matrix_client SET access_token=?, next_batch=?, filter_id=?, sync=?, autojoin=?, display_name=?, avatar_url=? WHERE user_id=?",
|
||||
mxc.AccessToken, mxc.NextBatch, mxc.FilterID, mxc.Sync, mxc.AutoJoinRooms, mxc.DisplayName, mxc.AvatarURL, mxc.UserID)
|
||||
return err
|
||||
}
|
@ -1,136 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"maubot.xyz"
|
||||
log "maunium.net/go/maulogger"
|
||||
)
|
||||
|
||||
type CommandSpec struct {
|
||||
db *Database
|
||||
sql *sql.DB
|
||||
|
||||
*maubot.CommandSpec
|
||||
Owner string `json:"owner"`
|
||||
Client string `json:"client"`
|
||||
}
|
||||
|
||||
type CommandSpecStatic struct {
|
||||
db *Database
|
||||
sql *sql.DB
|
||||
}
|
||||
|
||||
func (css *CommandSpecStatic) CreateTable() error {
|
||||
_, err := css.sql.Exec(`CREATE TABLE IF NOT EXISTS command_spec (
|
||||
owner VARCHAR(255),
|
||||
client VARCHAR(255),
|
||||
spec TEXT,
|
||||
|
||||
PRIMARY KEY (owner, client),
|
||||
FOREIGN KEY (owner) REFERENCES plugin(id)
|
||||
ON DELETE CASCADE ON UPDATE CASCADE,
|
||||
FOREIGN KEY (client) REFERENCES matrix_client(user_id)
|
||||
ON DELETE CASCADE ON UPDATE CASCADE
|
||||
)`)
|
||||
return err
|
||||
}
|
||||
|
||||
func (css *CommandSpecStatic) Get(owner, client string) *CommandSpec {
|
||||
rows, err := css.sql.Query("SELECT * FROM command_spec WHERE owner=? AND client=?", owner, client)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to Get(%s, %s): %v\n", owner, client, err)
|
||||
}
|
||||
return css.New().Scan(rows)
|
||||
}
|
||||
|
||||
func (css *CommandSpecStatic) GetOrCreate(owner, client string) (spec *CommandSpec) {
|
||||
spec = css.Get(owner, client)
|
||||
if spec == nil {
|
||||
spec = css.New()
|
||||
spec.Owner = owner
|
||||
spec.Client = client
|
||||
spec.Insert()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (css *CommandSpecStatic) getAllByQuery(query string, args ...interface{}) (specs []*CommandSpec) {
|
||||
rows, err := css.sql.Query(query, args...)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to getAllByQuery(%s): %v\n", query, err)
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
specs = append(specs, css.New().Scan(rows))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (css *CommandSpecStatic) GetAllByOwner(owner string) []*CommandSpec {
|
||||
return css.getAllByQuery("SELECT * FROM command_spec WHERE owner=?", owner)
|
||||
}
|
||||
|
||||
func (css *CommandSpecStatic) GetAllByClient(client string) []*CommandSpec {
|
||||
return css.getAllByQuery("SELECT * FROM command_spec WHERE client=?", client)
|
||||
}
|
||||
|
||||
func (css *CommandSpecStatic) New() *CommandSpec {
|
||||
return &CommandSpec{
|
||||
db: css.db,
|
||||
sql: css.sql,
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CommandSpec) Scan(row Scannable) *CommandSpec {
|
||||
var spec string
|
||||
err := row.Scan(&cs.Owner, &cs.Client, &spec)
|
||||
if err != nil {
|
||||
log.Errorln("CommandSpec scan failed:", err)
|
||||
return cs
|
||||
}
|
||||
cs.CommandSpec = &maubot.CommandSpec{}
|
||||
err = json.Unmarshal([]byte(spec), cs.CommandSpec)
|
||||
if err != nil {
|
||||
log.Errorln("CommandSpec parse failed:", err)
|
||||
}
|
||||
return cs
|
||||
}
|
||||
|
||||
func (cs *CommandSpec) Insert() error {
|
||||
data, err := json.Marshal(cs.CommandSpec)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = cs.sql.Exec("INSERT INTO command_spec (owner, client, spec) VALUES (?, ?, ?)",
|
||||
cs.Owner, cs.Client, string(data))
|
||||
return err
|
||||
}
|
||||
|
||||
func (cs *CommandSpec) Update() error {
|
||||
data, err := json.Marshal(cs.CommandSpec)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = cs.sql.Exec("UPDATE command_spec SET spec=? WHERE owner=? AND client=?",
|
||||
string(data), cs.Owner, cs.Client)
|
||||
return err
|
||||
}
|
@ -1,74 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
log "maunium.net/go/maulogger"
|
||||
)
|
||||
|
||||
type Scannable interface {
|
||||
Scan(...interface{}) error
|
||||
}
|
||||
|
||||
type Database struct {
|
||||
Type string `yaml:"type"`
|
||||
Name string `yaml:"name"`
|
||||
|
||||
MatrixClient *MatrixClientStatic `yaml:"-"`
|
||||
Plugin *PluginStatic `yaml:"-"`
|
||||
CommandSpec *CommandSpecStatic `yaml:"-"`
|
||||
|
||||
sql *sql.DB
|
||||
}
|
||||
|
||||
func (db *Database) Connect() (err error) {
|
||||
db.sql, err = sql.Open(db.Type, db.Name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
db.MatrixClient = &MatrixClientStatic{db: db, sql: db.sql}
|
||||
db.Plugin = &PluginStatic{db: db, sql: db.sql}
|
||||
db.CommandSpec = &CommandSpecStatic{db: db, sql: db.sql}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Database) CreateTables() {
|
||||
log.Debugln("Creating database tables")
|
||||
|
||||
err := db.MatrixClient.CreateTable()
|
||||
if err != nil {
|
||||
log.Errorln("Failed to create matrix_client table:", err)
|
||||
}
|
||||
|
||||
err = db.Plugin.CreateTable()
|
||||
if err != nil {
|
||||
log.Errorln("Failed to create plugin table:", err)
|
||||
}
|
||||
|
||||
err = db.CommandSpec.CreateTable()
|
||||
if err != nil {
|
||||
log.Errorln("Failed to create command_spec table:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (db *Database) SQL() *sql.DB {
|
||||
return db.sql
|
||||
}
|
@ -1,105 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
log "maunium.net/go/maulogger"
|
||||
)
|
||||
|
||||
type Plugin struct {
|
||||
db *Database
|
||||
sql *sql.DB
|
||||
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Enabled bool `json:"enabled"`
|
||||
UserID string `json:"user_id"`
|
||||
//User *MatrixClient `json:"-"`
|
||||
}
|
||||
|
||||
type PluginStatic struct {
|
||||
db *Database
|
||||
sql *sql.DB
|
||||
}
|
||||
|
||||
func (ps *PluginStatic) CreateTable() error {
|
||||
_, err := ps.sql.Exec(`CREATE TABLE IF NOT EXISTS plugin (
|
||||
id VARCHAR(255) PRIMARY KEY,
|
||||
type VARCHAR(255) NOT NULL,
|
||||
enabled BOOLEAN NOT NULL,
|
||||
|
||||
user_id VARCHAR(255) NOT NULL,
|
||||
|
||||
FOREIGN KEY (user_id) REFERENCES matrix_client(user_id)
|
||||
ON DELETE RESTRICT ON UPDATE CASCADE
|
||||
)`)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ps *PluginStatic) Get(id string) *Plugin {
|
||||
row := ps.sql.QueryRow("SELECT * FROM plugin WHERE id=?", id)
|
||||
if row != nil {
|
||||
return ps.New().Scan(row)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ps *PluginStatic) GetAll() (plugins []*Plugin) {
|
||||
rows, err := ps.sql.Query("SELECT * FROM plugin")
|
||||
if err != nil || rows == nil {
|
||||
return nil
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
plugins = append(plugins, ps.New().Scan(rows))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (ps *PluginStatic) New() *Plugin {
|
||||
return &Plugin{
|
||||
db: ps.db,
|
||||
sql: ps.sql,
|
||||
}
|
||||
}
|
||||
|
||||
/*func (p *Plugin) LoadUser() *Plugin {
|
||||
p.User = p.db.MatrixClient.Get(p.UserID)
|
||||
return p
|
||||
}*/
|
||||
|
||||
func (p *Plugin) Scan(row Scannable) *Plugin {
|
||||
err := row.Scan(&p.ID, &p.Type, &p.Enabled, &p.UserID)
|
||||
if err != nil {
|
||||
log.Errorln("Plugin scan failed:", err)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *Plugin) Insert() error {
|
||||
_, err := p.sql.Exec("INSERT INTO plugin (id, type, enabled, user_id) VALUES (?, ?, ?, ?)",
|
||||
p.ID, p.Type, p.Enabled, p.UserID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Plugin) Update() error {
|
||||
_, err := p.sql.Exec("UPDATE plugin SET enabled=? WHERE id=?",
|
||||
p.Enabled, p.ID)
|
||||
return err
|
||||
}
|
28
docker-run.sh
Normal file
28
docker-run.sh
Normal file
@ -0,0 +1,28 @@
|
||||
#!/bin/sh
|
||||
|
||||
# Define functions.
|
||||
function fixperms {
|
||||
chown -R $UID:$GID /data /opt/maubot
|
||||
}
|
||||
|
||||
cd /opt/maubot
|
||||
|
||||
# Replace database path in config.
|
||||
sed -i "s#sqlite:///maubot.db#sqlite:////data/maubot.db#" /data/config.yaml
|
||||
sed -i "s#- ./plugins#- /data/plugins#" /data/config.yaml
|
||||
|
||||
# Check that database is in the right state
|
||||
alembic -x config=/data/config.yaml upgrade head
|
||||
|
||||
if [ ! -f /data/config.yaml ]; then
|
||||
cp example-config.yaml /data/config.yaml
|
||||
echo "Didn't find a config file."
|
||||
echo "Copied default config file to /data/config.yaml"
|
||||
echo "Modify that config file to your liking."
|
||||
echo "Start the container again after that to generate the registration file."
|
||||
fixperms
|
||||
exit
|
||||
fi
|
||||
|
||||
fixperms
|
||||
exec su-exec $UID:$GID python3 -m maubot -c /data/config.yaml
|
@ -1,19 +1,55 @@
|
||||
database:
|
||||
type: sqlite3
|
||||
name: maubot.db
|
||||
|
||||
logging:
|
||||
directory: ./logs
|
||||
file_mode: 0600
|
||||
print_level: DEBUG
|
||||
file_name_format: "{{.Date}}-{{.Index}}.log"
|
||||
file_date_format: 2006-01-02
|
||||
timestamp_format: Jan _2, 2006 15:04:05
|
||||
# The full URI to the database. SQLite and Postgres are fully supported.
|
||||
# Other DBMSes supported by SQLAlchemy may or may not work.
|
||||
# Format examples:
|
||||
# SQLite: sqlite:///filename.db
|
||||
# Postgres: postgres://username:password@hostname/dbname
|
||||
database: sqlite:///maubot.db
|
||||
|
||||
# If multiple directories have a plugin with the same name, the first directory is used.
|
||||
plugin_directories:
|
||||
- ./plugins
|
||||
|
||||
server:
|
||||
listen: 0.0.0.0:29316
|
||||
base_path: /_matrix/maubot
|
||||
# The IP and port to listen to.
|
||||
hostname: 0.0.0.0
|
||||
port: 29316
|
||||
# The base management API path.
|
||||
base_path: /_matrix/maubot
|
||||
# The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1.
|
||||
appservice_base_path: /_matrix/app/v1
|
||||
# The shared secret to authorize users of the API.
|
||||
# Set to "generate" to generate and save a new token at startup.
|
||||
shared_secret: generate
|
||||
|
||||
admins:
|
||||
- "@admin:example.com"
|
||||
|
||||
# Python logging configuration.
|
||||
#
|
||||
# See section 16.7.2 of the Python documentation for more info:
|
||||
# https://docs.python.org/3.6/library/logging.config.html#configuration-dictionary-schema
|
||||
logging:
|
||||
version: 1
|
||||
formatters:
|
||||
precise:
|
||||
format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s"
|
||||
handlers:
|
||||
file:
|
||||
class: logging.handlers.RotatingFileHandler
|
||||
formatter: precise
|
||||
filename: ./logs/maubot.log
|
||||
maxBytes: 10485760
|
||||
backupCount: 10
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
formatter: precise
|
||||
loggers:
|
||||
maubot:
|
||||
level: DEBUG
|
||||
mautrix:
|
||||
level: DEBUG
|
||||
aiohttp:
|
||||
level: INFO
|
||||
root:
|
||||
level: DEBUG
|
||||
handlers: [file, console]
|
||||
|
16
example-maubot.ini
Normal file
16
example-maubot.ini
Normal file
@ -0,0 +1,16 @@
|
||||
# This is an example maubot plugin definition file.
|
||||
# All plugins must include a file like this named "maubot.ini" in their root directory.
|
||||
[maubot]
|
||||
# The unique ID for the plugin. Java package naming style.
|
||||
ID = xyz.maubot.plugin
|
||||
# A PEP 440 compliant version string.
|
||||
Version = 1.0.0
|
||||
# The comma-separated list of modules to load from the plugin archive.
|
||||
# Submodules that are imported by modules listed here don't need to be listed separately.
|
||||
# However, top-level modules must always be listed even if they're imported by other modules.
|
||||
Modules = plugin
|
||||
# The main class of the plugin. Format: module/Class
|
||||
# If `module` is omitted, will default to last module specified in the module list.
|
||||
# Even if `module` is not omitted here, it must be included in the modules list.
|
||||
# The main class must extend maubot.Plugin
|
||||
MainClass = PluginClass
|
41
logging.go
41
logging.go
@ -1,41 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package maubot
|
||||
|
||||
type Logger interface {
|
||||
Write(p []byte) (n int, err error)
|
||||
Debug(parts ...interface{})
|
||||
Debugln(parts ...interface{})
|
||||
Debugf(message string, args ...interface{})
|
||||
Debugfln(message string, args ...interface{})
|
||||
Info(parts ...interface{})
|
||||
Infoln(parts ...interface{})
|
||||
Infof(message string, args ...interface{})
|
||||
Infofln(message string, args ...interface{})
|
||||
Warn(parts ...interface{})
|
||||
Warnln(parts ...interface{})
|
||||
Warnf(message string, args ...interface{})
|
||||
Warnfln(message string, args ...interface{})
|
||||
Error(parts ...interface{})
|
||||
Errorln(parts ...interface{})
|
||||
Errorf(message string, args ...interface{})
|
||||
Errorfln(message string, args ...interface{})
|
||||
Fatal(parts ...interface{})
|
||||
Fatalln(parts ...interface{})
|
||||
Fatalf(message string, args ...interface{})
|
||||
Fatalfln(message string, args ...interface{})
|
||||
}
|
125
matrix.go
125
matrix.go
@ -1,125 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package maubot
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"maunium.net/go/gomatrix"
|
||||
)
|
||||
|
||||
type EventHandler func(*Event) EventHandlerResult
|
||||
type EventHandlerResult int
|
||||
type CommandHandlerResult = EventHandlerResult
|
||||
|
||||
const (
|
||||
Continue EventHandlerResult = iota
|
||||
StopEventPropagation
|
||||
StopCommandPropagation CommandHandlerResult = iota
|
||||
)
|
||||
|
||||
type GomatrixClient interface {
|
||||
//d <method> = disabled
|
||||
//r <method> = replaced
|
||||
|
||||
BanUser(roomID string, req *gomatrix.ReqBanUser) (resp *gomatrix.RespBanUser, err error)
|
||||
//d BuildBaseURL(urlPath ...string) string
|
||||
//d BuildURL(urlPath ...string) string
|
||||
//d BuildURLWithQuery(urlPath []string, urlQuery map[string]string) string
|
||||
//d ClearCredentials()
|
||||
//d CreateFilter(filter json.RawMessage) (resp *gomatrix.RespCreateFilter, err error)
|
||||
CreateRoom(req *gomatrix.ReqCreateRoom) (resp *gomatrix.RespCreateRoom, err error)
|
||||
Download(mxcURL string) (io.ReadCloser, error)
|
||||
DownloadBytes(mxcURL string) ([]byte, error)
|
||||
ForgetRoom(roomID string) (resp *gomatrix.RespForgetRoom, err error)
|
||||
GetAvatarURL() (url string, err error)
|
||||
GetDisplayName(mxid string) (resp *gomatrix.RespUserDisplayName, err error)
|
||||
//r GetEvent(roomID, eventID string) (resp *gomatrix.Event, err error)
|
||||
GetOwnDisplayName() (resp *gomatrix.RespUserDisplayName, err error)
|
||||
InviteUser(roomID string, req *gomatrix.ReqInviteUser) (resp *gomatrix.RespInviteUser, err error)
|
||||
InviteUserByThirdParty(roomID string, req *gomatrix.ReqInvite3PID) (resp *gomatrix.RespInviteUser, err error)
|
||||
//r JoinRoom(roomIDorAlias, serverName string, content interface{}) (resp *gomatrix.RespJoinRoom, err error)
|
||||
JoinedMembers(roomID string) (resp *gomatrix.RespJoinedMembers, err error)
|
||||
JoinedRooms() (resp *gomatrix.RespJoinedRooms, err error)
|
||||
KickUser(roomID string, req *gomatrix.ReqKickUser) (resp *gomatrix.RespKickUser, err error)
|
||||
LeaveRoom(roomID string) (resp *gomatrix.RespLeaveRoom, err error)
|
||||
//d Login(req *gomatrix.ReqLogin) (resp *gomatrix.RespLogin, err error)
|
||||
//d Logout() (resp *gomatrix.RespLogout, err error)
|
||||
MakeRequest(method string, httpURL string, reqBody interface{}, resBody interface{}) ([]byte, error)
|
||||
MarkRead(roomID, eventID string) (err error)
|
||||
Messages(roomID, from, to string, dir rune, limit int) (resp *gomatrix.RespMessages, err error)
|
||||
RedactEvent(roomID, eventID string, req *gomatrix.ReqRedact) (resp *gomatrix.RespSendEvent, err error)
|
||||
//d Register(req *gomatrix.ReqRegister) (*gomatrix.RespRegister, *gomatrix.RespUserInteractive, error)
|
||||
//d RegisterDummy(req *gomatrix.ReqRegister) (*gomatrix.RespRegister, error)
|
||||
//d RegisterGuest(req *gomatrix.ReqRegister) (*gomatrix.RespRegister, *gomatrix.RespUserInteractive, error)
|
||||
SendImage(roomID, body, url string) (*gomatrix.RespSendEvent, error)
|
||||
//SendMassagedMessageEvent(roomID string, eventType gomatrix.EventType, contentJSON interface{}, ts int64) (resp *gomatrix.RespSendEvent, err error)
|
||||
//SendMassagedStateEvent(roomID string, eventType gomatrix.EventType, stateKey string, contentJSON interface{}, ts int64) (resp *gomatrix.RespSendEvent, err error)
|
||||
//r SendMessageEvent(roomID string, eventType gomatrix.EventType, contentJSON interface{}) (resp *gomatrix.RespSendEvent, err error)
|
||||
SendNotice(roomID, text string) (*gomatrix.RespSendEvent, error)
|
||||
SendStateEvent(roomID string, eventType gomatrix.EventType, stateKey string, contentJSON interface{}) (resp *gomatrix.RespSendEvent, err error)
|
||||
SendText(roomID, text string) (*gomatrix.RespSendEvent, error)
|
||||
SendVideo(roomID, body, url string) (*gomatrix.RespSendEvent, error)
|
||||
SetAvatarURL(url string) (err error)
|
||||
SetCredentials(userID, accessToken string)
|
||||
SetDisplayName(displayName string) (err error)
|
||||
SetPresence(status string) (err error)
|
||||
StateEvent(roomID string, eventType gomatrix.EventType, stateKey string, outContent interface{}) (err error)
|
||||
//d StopSync()
|
||||
//d Sync() error
|
||||
//d SyncRequest(timeout int, since, filterID string, fullState bool, setPresence string) (resp *gomatrix.RespSync, err error)
|
||||
TurnServer() (resp *gomatrix.RespTurnServer, err error)
|
||||
UnbanUser(roomID string, req *gomatrix.ReqUnbanUser) (resp *gomatrix.RespUnbanUser, err error)
|
||||
Upload(content io.Reader, contentType string, contentLength int64) (*gomatrix.RespMediaUpload, error)
|
||||
UploadBytes(data []byte, contentType string) (*gomatrix.RespMediaUpload, error)
|
||||
UploadLink(link string) (*gomatrix.RespMediaUpload, error)
|
||||
UserTyping(roomID string, typing bool, timeout int64) (resp *gomatrix.RespTyping, err error)
|
||||
Versions() (resp *gomatrix.RespVersions, err error)
|
||||
}
|
||||
|
||||
type MBMatrixClient interface {
|
||||
AddEventHandler(gomatrix.EventType, EventHandler)
|
||||
AddCommandHandler(string, CommandHandler)
|
||||
SetCommandSpec(*CommandSpec)
|
||||
|
||||
GetEvent(roomID, eventID string) *Event
|
||||
JoinRoom(roomIDOrAlias string) (resp *gomatrix.RespJoinRoom, err error)
|
||||
SendMessage(roomID, text string) (eventID string, err error)
|
||||
SendMessagef(roomID, text string, args ...interface{}) (eventID string, err error)
|
||||
SendContent(roomID string, content gomatrix.Content) (eventID string, err error)
|
||||
SendMessageEvent(roomID string, evtType gomatrix.EventType, content interface{}) (eventID string, err error)
|
||||
}
|
||||
|
||||
type MatrixClient interface {
|
||||
GomatrixClient
|
||||
MBMatrixClient
|
||||
}
|
||||
|
||||
type EventFuncs interface {
|
||||
MarkRead() error
|
||||
Reply(string) (string, error)
|
||||
ReplyContent(gomatrix.Content) (string, error)
|
||||
SendMessage(string) (string, error)
|
||||
SendMessagef(string, ...interface{}) (string, error)
|
||||
SendContent(gomatrix.Content) (string, error)
|
||||
SendMessageEvent(evtType gomatrix.EventType, content interface{}) (eventID string, err error)
|
||||
}
|
||||
|
||||
type Event struct {
|
||||
EventFuncs
|
||||
*gomatrix.Event
|
||||
}
|
@ -1,208 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"maunium.net/go/gomatrix"
|
||||
log "maunium.net/go/maulogger"
|
||||
|
||||
"maubot.xyz"
|
||||
)
|
||||
|
||||
type ParsedCommand struct {
|
||||
Name string
|
||||
IsPassive bool
|
||||
Arguments []string
|
||||
StartsWith string
|
||||
Matches *regexp.Regexp
|
||||
MatchAgainst string
|
||||
MatchesEvent interface{}
|
||||
}
|
||||
|
||||
func (pc *ParsedCommand) parseCommandSyntax(command maubot.Command) error {
|
||||
regexBuilder := &strings.Builder{}
|
||||
swBuilder := &strings.Builder{}
|
||||
argumentEncountered := false
|
||||
|
||||
regexBuilder.WriteString("^!")
|
||||
swBuilder.WriteRune('!')
|
||||
words := strings.Split(command.Syntax, " ")
|
||||
for i, word := range words {
|
||||
argument, ok := command.Arguments[word]
|
||||
// TODO enable $ check?
|
||||
if ok && len(word) > 0 /*&& word[0] == '$'*/ {
|
||||
argumentEncountered = true
|
||||
regex := argument.Matches
|
||||
if !argument.Required {
|
||||
regex = fmt.Sprintf("(?:%s)?", regex)
|
||||
} else {
|
||||
regex = fmt.Sprintf("(%s)", regex)
|
||||
}
|
||||
pc.Arguments = append(pc.Arguments, word)
|
||||
regexBuilder.WriteString(regex)
|
||||
} else {
|
||||
if !argumentEncountered {
|
||||
swBuilder.WriteString(word)
|
||||
}
|
||||
regexBuilder.WriteString(regexp.QuoteMeta(word))
|
||||
}
|
||||
|
||||
if i < len(words)-1 {
|
||||
if !argumentEncountered {
|
||||
swBuilder.WriteRune(' ')
|
||||
}
|
||||
regexBuilder.WriteRune(' ')
|
||||
}
|
||||
}
|
||||
regexBuilder.WriteRune('$')
|
||||
|
||||
var err error
|
||||
pc.StartsWith = swBuilder.String()
|
||||
// Trim the extra space at the end added in the parse loop
|
||||
pc.StartsWith = pc.StartsWith[:len(pc.StartsWith)-1]
|
||||
pc.Matches, err = regexp.Compile(regexBuilder.String())
|
||||
pc.MatchAgainst = "body"
|
||||
return err
|
||||
}
|
||||
|
||||
func (pc *ParsedCommand) parsePassiveCommandSyntax(command maubot.PassiveCommand) error {
|
||||
pc.MatchAgainst = command.MatchAgainst
|
||||
var err error
|
||||
pc.Matches, err = regexp.Compile(command.Matches)
|
||||
pc.MatchesEvent = command.MatchEvent
|
||||
return err
|
||||
}
|
||||
|
||||
func ParseSpec(spec *maubot.CommandSpec) (commands []*ParsedCommand) {
|
||||
for _, command := range spec.Commands {
|
||||
parsing := &ParsedCommand{
|
||||
Name: command.Syntax,
|
||||
IsPassive: false,
|
||||
}
|
||||
err := parsing.parseCommandSyntax(command)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to parse regex of command %s: %v\n", command.Syntax, err)
|
||||
continue
|
||||
}
|
||||
commands = append(commands, parsing)
|
||||
}
|
||||
for _, command := range spec.PassiveCommands {
|
||||
parsing := &ParsedCommand{
|
||||
Name: command.Name,
|
||||
IsPassive: true,
|
||||
}
|
||||
err := parsing.parsePassiveCommandSyntax(command)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to parse regex of passive command %s: %v\n", command.Name, err)
|
||||
continue
|
||||
}
|
||||
commands = append(commands, parsing)
|
||||
}
|
||||
return commands
|
||||
}
|
||||
|
||||
func deepGet(from map[string]interface{}, path string) interface{} {
|
||||
for {
|
||||
dotIndex := strings.IndexRune(path, '.')
|
||||
if dotIndex == -1 {
|
||||
return from[path]
|
||||
}
|
||||
|
||||
var key string
|
||||
key, path = path[:dotIndex], path[dotIndex+1:]
|
||||
var ok bool
|
||||
from, ok = from[key].(map[string]interface{})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pc *ParsedCommand) MatchActive(evt *gomatrix.Event) bool {
|
||||
if !strings.HasPrefix(evt.Content.Body, pc.StartsWith) {
|
||||
return false
|
||||
}
|
||||
match := pc.Matches.FindStringSubmatch(evt.Content.Body)
|
||||
if match == nil {
|
||||
return false
|
||||
}
|
||||
// First element is whole content
|
||||
match = match[1:]
|
||||
|
||||
command := &gomatrix.MatchedCommand{
|
||||
Arguments: make(map[string]string),
|
||||
}
|
||||
for i, value := range match {
|
||||
if i >= len(pc.Arguments) {
|
||||
break
|
||||
}
|
||||
key := pc.Arguments[i]
|
||||
command.Arguments[key] = value
|
||||
}
|
||||
|
||||
command.Matched = pc.Name
|
||||
// TODO add evt.Content.Command.Target?
|
||||
evt.Content.Command = command
|
||||
return true
|
||||
}
|
||||
|
||||
func (pc *ParsedCommand) MatchPassive(evt *gomatrix.Event) bool {
|
||||
matchAgainst := evt.Content.Body
|
||||
switch pc.MatchAgainst {
|
||||
case maubot.MatchAgainstBody:
|
||||
matchAgainst = evt.Content.Body
|
||||
case "formatted_body":
|
||||
matchAgainst = evt.Content.FormattedBody
|
||||
default:
|
||||
matchAgainstDirect, ok := deepGet(evt.Content.Raw, pc.MatchAgainst).(string)
|
||||
if ok {
|
||||
matchAgainst = matchAgainstDirect
|
||||
}
|
||||
}
|
||||
|
||||
if pc.MatchesEvent != nil && !maubot.JSONLeftEquals(pc.MatchesEvent, evt) {
|
||||
return false
|
||||
}
|
||||
|
||||
matches := pc.Matches.FindAllStringSubmatch(matchAgainst, -1)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if evt.Unsigned.PassiveCommand == nil {
|
||||
evt.Unsigned.PassiveCommand = make(map[string]*gomatrix.MatchedPassiveCommand)
|
||||
}
|
||||
evt.Unsigned.PassiveCommand[pc.Name] = &gomatrix.MatchedPassiveCommand{
|
||||
Captured: matches,
|
||||
}
|
||||
//evt.Unsigned.PassiveCommand.Matched = pc.Name
|
||||
//evt.Unsigned.PassiveCommand.Captured = matches
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (pc *ParsedCommand) Match(evt *gomatrix.Event) bool {
|
||||
if pc.IsPassive {
|
||||
return pc.MatchPassive(evt)
|
||||
} else {
|
||||
return pc.MatchActive(evt)
|
||||
}
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"maubot.xyz"
|
||||
"maunium.net/go/gomatrix"
|
||||
"maunium.net/go/gomatrix/format"
|
||||
)
|
||||
|
||||
type EventFuncsImpl struct {
|
||||
*gomatrix.Event
|
||||
Client *Client
|
||||
}
|
||||
|
||||
func (client *Client) ParseEvent(mxEvent *gomatrix.Event) *maubot.Event {
|
||||
if mxEvent == nil {
|
||||
return nil
|
||||
}
|
||||
mxEvent.Content.RemoveReplyFallback()
|
||||
return &maubot.Event{
|
||||
EventFuncs: &EventFuncsImpl{
|
||||
Event: mxEvent,
|
||||
Client: client,
|
||||
},
|
||||
Event: mxEvent,
|
||||
}
|
||||
}
|
||||
|
||||
func (evt *EventFuncsImpl) MarkRead() error {
|
||||
return evt.Client.MarkRead(evt.RoomID, evt.ID)
|
||||
}
|
||||
|
||||
func (evt *EventFuncsImpl) Reply(text string) (string, error) {
|
||||
content := format.RenderMarkdown(text)
|
||||
content.MsgType = gomatrix.MsgNotice
|
||||
return evt.ReplyContent(content)
|
||||
}
|
||||
|
||||
func (evt *EventFuncsImpl) ReplyContent(content gomatrix.Content) (string, error) {
|
||||
content.SetReply(evt.Event)
|
||||
return evt.SendContent(content)
|
||||
}
|
||||
|
||||
func (evt *EventFuncsImpl) SendMessage(text string) (string, error) {
|
||||
return evt.Client.SendMessage(evt.RoomID, text)
|
||||
}
|
||||
|
||||
func (evt *EventFuncsImpl) SendMessagef(text string, args ...interface{}) (string, error) {
|
||||
return evt.Client.SendMessagef(evt.RoomID, text, args...)
|
||||
}
|
||||
|
||||
func (evt *EventFuncsImpl) SendContent(content gomatrix.Content) (string, error) {
|
||||
return evt.Client.SendContent(evt.RoomID, content)
|
||||
}
|
||||
|
||||
func (evt *EventFuncsImpl) SendMessageEvent(evtType gomatrix.EventType, content interface{}) (eventID string, err error) {
|
||||
return evt.Client.SendMessageEvent(evt.RoomID, evtType, content)
|
||||
}
|
@ -1,229 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/html"
|
||||
)
|
||||
|
||||
var matrixToURL = regexp.MustCompile("^(?:https?://)?(?:www\\.)?matrix\\.to/#/([#@!].*)")
|
||||
|
||||
type htmlParser struct{}
|
||||
|
||||
type taggedString struct {
|
||||
string
|
||||
tag string
|
||||
}
|
||||
|
||||
func (parser *htmlParser) getAttribute(node *html.Node, attribute string) string {
|
||||
for _, attr := range node.Attr {
|
||||
if attr.Key == attribute {
|
||||
return attr.Val
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func digits(num int) int {
|
||||
return int(math.Floor(math.Log10(float64(num))) + 1)
|
||||
}
|
||||
|
||||
func (parser *htmlParser) listToString(node *html.Node, stripLinebreak bool) string {
|
||||
ordered := node.Data == "ol"
|
||||
taggedChildren := parser.nodeToTaggedStrings(node.FirstChild, stripLinebreak)
|
||||
counter := 1
|
||||
indentLength := 0
|
||||
if ordered {
|
||||
start := parser.getAttribute(node, "start")
|
||||
if len(start) > 0 {
|
||||
counter, _ = strconv.Atoi(start)
|
||||
}
|
||||
|
||||
longestIndex := (counter - 1) + len(taggedChildren)
|
||||
indentLength = digits(longestIndex)
|
||||
}
|
||||
indent := strings.Repeat(" ", indentLength+2)
|
||||
var children []string
|
||||
for _, child := range taggedChildren {
|
||||
if child.tag != "li" {
|
||||
continue
|
||||
}
|
||||
var prefix string
|
||||
if ordered {
|
||||
indexPadding := indentLength - digits(counter)
|
||||
prefix = fmt.Sprintf("%d. %s", counter, strings.Repeat(" ", indexPadding))
|
||||
} else {
|
||||
prefix = "● "
|
||||
}
|
||||
str := prefix + child.string
|
||||
counter++
|
||||
parts := strings.Split(str, "\n")
|
||||
for i, part := range parts[1:] {
|
||||
parts[i+1] = indent + part
|
||||
}
|
||||
str = strings.Join(parts, "\n")
|
||||
children = append(children, str)
|
||||
}
|
||||
return strings.Join(children, "\n")
|
||||
}
|
||||
|
||||
func (parser *htmlParser) basicFormatToString(node *html.Node, stripLinebreak bool) string {
|
||||
str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak)
|
||||
switch node.Data {
|
||||
case "b", "strong":
|
||||
return fmt.Sprintf("**%s**", str)
|
||||
case "i", "em":
|
||||
return fmt.Sprintf("_%s_", str)
|
||||
case "s", "del":
|
||||
return fmt.Sprintf("~~%s~~", str)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
func (parser *htmlParser) headerToString(node *html.Node, stripLinebreak bool) string {
|
||||
children := parser.nodeToStrings(node.FirstChild, stripLinebreak)
|
||||
length := int(node.Data[1] - '0')
|
||||
prefix := strings.Repeat("#", length) + " "
|
||||
return prefix + strings.Join(children, "")
|
||||
}
|
||||
|
||||
func (parser *htmlParser) blockquoteToString(node *html.Node, stripLinebreak bool) string {
|
||||
str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak)
|
||||
childrenArr := strings.Split(strings.TrimSpace(str), "\n")
|
||||
for index, child := range childrenArr {
|
||||
childrenArr[index] = "> " + child
|
||||
}
|
||||
return strings.Join(childrenArr, "\n")
|
||||
}
|
||||
|
||||
func (parser *htmlParser) linkToString(node *html.Node, stripLinebreak bool) string {
|
||||
str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak)
|
||||
href := parser.getAttribute(node, "href")
|
||||
if len(href) == 0 {
|
||||
return str
|
||||
}
|
||||
match := matrixToURL.FindStringSubmatch(href)
|
||||
if len(match) == 2 {
|
||||
// pillTarget := match[1]
|
||||
// if pillTarget[0] == '@' {
|
||||
// if member := parser.room.GetMember(pillTarget); member != nil {
|
||||
// return member.DisplayName
|
||||
// }
|
||||
// }
|
||||
// return pillTarget
|
||||
return str
|
||||
}
|
||||
return fmt.Sprintf("%s (%s)", str, href)
|
||||
}
|
||||
|
||||
func (parser *htmlParser) tagToString(node *html.Node, stripLinebreak bool) string {
|
||||
switch node.Data {
|
||||
case "blockquote":
|
||||
return parser.blockquoteToString(node, stripLinebreak)
|
||||
case "ol", "ul":
|
||||
return parser.listToString(node, stripLinebreak)
|
||||
case "h1", "h2", "h3", "h4", "h5", "h6":
|
||||
return parser.headerToString(node, stripLinebreak)
|
||||
case "br":
|
||||
return "\n"
|
||||
case "b", "strong", "i", "em", "s", "del", "u", "ins":
|
||||
return parser.basicFormatToString(node, stripLinebreak)
|
||||
case "a":
|
||||
return parser.linkToString(node, stripLinebreak)
|
||||
case "p":
|
||||
return parser.nodeToTagAwareString(node.FirstChild, stripLinebreak) + "\n"
|
||||
case "pre":
|
||||
return parser.nodeToString(node.FirstChild, false)
|
||||
default:
|
||||
return parser.nodeToTagAwareString(node.FirstChild, stripLinebreak)
|
||||
}
|
||||
}
|
||||
|
||||
func (parser *htmlParser) singleNodeToString(node *html.Node, stripLinebreak bool) taggedString {
|
||||
switch node.Type {
|
||||
case html.TextNode:
|
||||
if stripLinebreak {
|
||||
node.Data = strings.Replace(node.Data, "\n", "", -1)
|
||||
}
|
||||
return taggedString{node.Data, "text"}
|
||||
case html.ElementNode:
|
||||
return taggedString{parser.tagToString(node, stripLinebreak), node.Data}
|
||||
case html.DocumentNode:
|
||||
return taggedString{parser.nodeToTagAwareString(node.FirstChild, stripLinebreak), "html"}
|
||||
default:
|
||||
return taggedString{"", "unknown"}
|
||||
}
|
||||
}
|
||||
|
||||
func (parser *htmlParser) nodeToTaggedStrings(node *html.Node, stripLinebreak bool) (strs []taggedString) {
|
||||
for ; node != nil; node = node.NextSibling {
|
||||
strs = append(strs, parser.singleNodeToString(node, stripLinebreak))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var BlockTags = []string{"p", "h1", "h2", "h3", "h4", "h5", "h6", "ol", "ul", "pre", "blockquote", "div", "hr", "table"}
|
||||
|
||||
func (parser *htmlParser) isBlockTag(tag string) bool {
|
||||
for _, blockTag := range BlockTags {
|
||||
if tag == blockTag {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (parser *htmlParser) nodeToTagAwareString(node *html.Node, stripLinebreak bool) string {
|
||||
strs := parser.nodeToTaggedStrings(node, stripLinebreak)
|
||||
var output strings.Builder
|
||||
for _, str := range strs {
|
||||
tstr := str.string
|
||||
if parser.isBlockTag(str.tag) {
|
||||
tstr = fmt.Sprintf("\n%s\n", tstr)
|
||||
}
|
||||
output.WriteString(tstr)
|
||||
}
|
||||
return strings.TrimSpace(output.String())
|
||||
}
|
||||
|
||||
func (parser *htmlParser) nodeToStrings(node *html.Node, stripLinebreak bool) (strs []string) {
|
||||
for ; node != nil; node = node.NextSibling {
|
||||
strs = append(strs, parser.singleNodeToString(node, stripLinebreak).string)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (parser *htmlParser) nodeToString(node *html.Node, stripLinebreak bool) string {
|
||||
return strings.Join(parser.nodeToStrings(node, stripLinebreak), "")
|
||||
}
|
||||
|
||||
func (parser *htmlParser) Parse(htmlData string) string {
|
||||
node, _ := html.Parse(strings.NewReader(htmlData))
|
||||
return parser.nodeToTagAwareString(node, true)
|
||||
}
|
||||
|
||||
func HTMLToText(html string) string {
|
||||
html = strings.Replace(html, "\t", " ", -1)
|
||||
str := (&htmlParser{}).Parse(html)
|
||||
return str
|
||||
}
|
191
matrix/matrix.go
191
matrix/matrix.go
@ -1,191 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"maubot.xyz"
|
||||
"maubot.xyz/database"
|
||||
"maunium.net/go/gomatrix"
|
||||
"maunium.net/go/gomatrix/format"
|
||||
log "maunium.net/go/maulogger"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
*gomatrix.Client
|
||||
syncer *MaubotSyncer
|
||||
handlers map[string][]maubot.CommandHandler
|
||||
commands []*ParsedCommand
|
||||
DB *database.MatrixClient
|
||||
}
|
||||
|
||||
func NewClient(db *database.MatrixClient) (*Client, error) {
|
||||
mxClient, err := gomatrix.NewClient(db.Homeserver, db.UserID, db.AccessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
Client: mxClient,
|
||||
handlers: make(map[string][]maubot.CommandHandler),
|
||||
commands: ParseSpec(db.Commands()),
|
||||
DB: db,
|
||||
}
|
||||
|
||||
client.syncer = NewMaubotSyncer(client, client.Store)
|
||||
client.Client.Syncer = client.syncer
|
||||
|
||||
client.AddEventHandler(gomatrix.StateMember, client.onJoin)
|
||||
client.AddEventHandler(gomatrix.EventMessage, client.onMessage)
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (client *Client) Proxy(owner string) *ClientProxy {
|
||||
return &ClientProxy{
|
||||
hiddenClient: client,
|
||||
owner: owner,
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) AddEventHandler(evt gomatrix.EventType, handler maubot.EventHandler) {
|
||||
client.syncer.OnEventType(evt, func(evt *maubot.Event) maubot.EventHandlerResult {
|
||||
if evt.Sender == client.UserID {
|
||||
return maubot.StopEventPropagation
|
||||
}
|
||||
return handler(evt)
|
||||
})
|
||||
}
|
||||
|
||||
func (client *Client) AddCommandHandler(owner, evt string, handler maubot.CommandHandler) {
|
||||
log.Debugln("Registering command handler for event", evt, "by", owner)
|
||||
list, ok := client.handlers[evt]
|
||||
if !ok {
|
||||
list = []maubot.CommandHandler{handler}
|
||||
} else {
|
||||
list = append(list, handler)
|
||||
}
|
||||
client.handlers[evt] = list
|
||||
}
|
||||
|
||||
func (client *Client) SetCommandSpec(owner string, spec *maubot.CommandSpec) {
|
||||
log.Debugln("Registering command spec for", owner, "on", client.UserID)
|
||||
changed := client.DB.SetCommandSpec(owner, spec)
|
||||
if changed {
|
||||
client.commands = ParseSpec(client.DB.Commands())
|
||||
log.Debugln("Command spec of", owner, "on", client.UserID, "updated.")
|
||||
}
|
||||
}
|
||||
|
||||
func (client *Client) GetEvent(roomID, eventID string) *maubot.Event {
|
||||
evt, err := client.Client.GetEvent(roomID, eventID)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to get event %s @ %s: %v\n", eventID, roomID, err)
|
||||
return nil
|
||||
}
|
||||
return client.ParseEvent(evt)
|
||||
}
|
||||
|
||||
func (client *Client) TriggerCommand(command *ParsedCommand, evt *maubot.Event) maubot.CommandHandlerResult {
|
||||
handlers, ok := client.handlers[command.Name]
|
||||
if !ok {
|
||||
log.Warnf("Command `%s` triggered by %s doesn't have any handlers.\n", command.Name, evt.Sender)
|
||||
return maubot.Continue
|
||||
}
|
||||
|
||||
log.Debugf("Command `%s` on client %s triggered by %s\n", command.Name, client.UserID, evt.Sender)
|
||||
for _, handler := range handlers {
|
||||
result := handler(evt)
|
||||
if result == maubot.StopCommandPropagation {
|
||||
break
|
||||
} else if result != maubot.Continue {
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
return maubot.Continue
|
||||
}
|
||||
|
||||
func (client *Client) onMessage(evt *maubot.Event) maubot.EventHandlerResult {
|
||||
for _, command := range client.commands {
|
||||
if command.Match(evt.Event) {
|
||||
return client.TriggerCommand(command, evt)
|
||||
}
|
||||
}
|
||||
return maubot.Continue
|
||||
}
|
||||
|
||||
func (client *Client) onJoin(evt *maubot.Event) maubot.EventHandlerResult {
|
||||
if client.DB.AutoJoinRooms && evt.GetStateKey() == client.DB.UserID && evt.Content.Membership == "invite" {
|
||||
client.JoinRoom(evt.RoomID)
|
||||
return maubot.StopEventPropagation
|
||||
}
|
||||
return maubot.Continue
|
||||
}
|
||||
|
||||
func (client *Client) JoinRoom(roomID string) (resp *gomatrix.RespJoinRoom, err error) {
|
||||
return client.Client.JoinRoom(roomID, "", nil)
|
||||
}
|
||||
|
||||
func (client *Client) SendMessage(roomID, text string) (string, error) {
|
||||
content := format.RenderMarkdown(text)
|
||||
content.MsgType = gomatrix.MsgNotice
|
||||
return client.SendContent(roomID, content)
|
||||
}
|
||||
|
||||
func (client *Client) SendMessagef(roomID, text string, args ...interface{}) (string, error) {
|
||||
content := format.RenderMarkdown(fmt.Sprintf(text, args...))
|
||||
content.MsgType = gomatrix.MsgNotice
|
||||
return client.SendContent(roomID, content)
|
||||
}
|
||||
|
||||
func (client *Client) SendContent(roomID string, content gomatrix.Content) (string, error) {
|
||||
return client.SendMessageEvent(roomID, gomatrix.EventMessage, content)
|
||||
}
|
||||
|
||||
func (client *Client) SendMessageEvent(roomID string, evtType gomatrix.EventType, content interface{}) (string, error) {
|
||||
resp, err := client.Client.SendMessageEvent(roomID, evtType, content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return resp.EventID, nil
|
||||
}
|
||||
|
||||
func (client *Client) Sync() {
|
||||
go func() {
|
||||
err := client.Client.Sync()
|
||||
if err != nil {
|
||||
log.Errorln("Sync() in client", client.UserID, "errored:", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
type hiddenClient = Client
|
||||
|
||||
type ClientProxy struct {
|
||||
*hiddenClient
|
||||
owner string
|
||||
}
|
||||
|
||||
func (cp *ClientProxy) AddCommandHandler(evt string, handler maubot.CommandHandler) {
|
||||
cp.hiddenClient.AddCommandHandler(cp.owner, evt, handler)
|
||||
}
|
||||
|
||||
func (cp *ClientProxy) SetCommandSpec(spec *maubot.CommandSpec) {
|
||||
cp.hiddenClient.SetCommandSpec(cp.owner, spec)
|
||||
}
|
147
matrix/sync.go
147
matrix/sync.go
@ -1,147 +0,0 @@
|
||||
package matrix
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"maubot.xyz"
|
||||
"maunium.net/go/gomatrix"
|
||||
)
|
||||
|
||||
type MaubotSyncer struct {
|
||||
Client *Client
|
||||
Store gomatrix.Storer
|
||||
listeners map[gomatrix.EventType][]maubot.EventHandler
|
||||
}
|
||||
|
||||
// NewDefaultSyncer returns an instantiated DefaultSyncer
|
||||
func NewMaubotSyncer(client *Client, store gomatrix.Storer) *MaubotSyncer {
|
||||
return &MaubotSyncer{
|
||||
Client: client,
|
||||
Store: store,
|
||||
listeners: make(map[gomatrix.EventType][]maubot.EventHandler),
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessResponse processes the /sync response in a way suitable for bots. "Suitable for bots" means a stream of
|
||||
// unrepeating events. Returns a fatal error if a listener panics.
|
||||
func (s *MaubotSyncer) ProcessResponse(res *gomatrix.RespSync, since string) (err error) {
|
||||
if !s.shouldProcessResponse(res, since) {
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("ProcessResponse panicked! userID=%s since=%s panic=%s\n%s", s.Client.UserID, since, r, debug.Stack())
|
||||
}
|
||||
}()
|
||||
|
||||
for roomID, roomData := range res.Rooms.Join {
|
||||
room := s.getOrCreateRoom(roomID)
|
||||
for _, event := range roomData.State.Events {
|
||||
event.RoomID = roomID
|
||||
room.UpdateState(event)
|
||||
s.notifyListeners(event)
|
||||
}
|
||||
for _, event := range roomData.Timeline.Events {
|
||||
event.RoomID = roomID
|
||||
s.notifyListeners(event)
|
||||
}
|
||||
}
|
||||
for roomID, roomData := range res.Rooms.Invite {
|
||||
room := s.getOrCreateRoom(roomID)
|
||||
for _, event := range roomData.State.Events {
|
||||
event.RoomID = roomID
|
||||
room.UpdateState(event)
|
||||
s.notifyListeners(event)
|
||||
}
|
||||
}
|
||||
for roomID, roomData := range res.Rooms.Leave {
|
||||
room := s.getOrCreateRoom(roomID)
|
||||
for _, event := range roomData.Timeline.Events {
|
||||
if event.StateKey != nil {
|
||||
event.RoomID = roomID
|
||||
room.UpdateState(event)
|
||||
s.notifyListeners(event)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// OnEventType allows callers to be notified when there are new events for the given event type.
|
||||
// There are no duplicate checks.
|
||||
func (s *MaubotSyncer) OnEventType(eventType gomatrix.EventType, callback maubot.EventHandler) {
|
||||
_, exists := s.listeners[eventType]
|
||||
if !exists {
|
||||
s.listeners[eventType] = []maubot.EventHandler{}
|
||||
}
|
||||
s.listeners[eventType] = append(s.listeners[eventType], callback)
|
||||
}
|
||||
|
||||
// shouldProcessResponse returns true if the response should be processed. May modify the response to remove
|
||||
// stuff that shouldn't be processed.
|
||||
func (s *MaubotSyncer) shouldProcessResponse(resp *gomatrix.RespSync, since string) bool {
|
||||
if since == "" {
|
||||
return false
|
||||
}
|
||||
// This is a horrible hack because /sync will return the most recent messages for a room
|
||||
// as soon as you /join it. We do NOT want to process those events in that particular room
|
||||
// because they may have already been processed (if you toggle the bot in/out of the room).
|
||||
//
|
||||
// Work around this by inspecting each room's timeline and seeing if an m.room.member event for us
|
||||
// exists and is "join" and then discard processing that room entirely if so.
|
||||
// TODO: We probably want to process messages from after the last join event in the timeline.
|
||||
for roomID, roomData := range resp.Rooms.Join {
|
||||
for i := len(roomData.Timeline.Events) - 1; i >= 0; i-- {
|
||||
evt := roomData.Timeline.Events[i]
|
||||
if evt.Type == gomatrix.StateMember && evt.GetStateKey() == s.Client.UserID {
|
||||
if evt.Content.Membership == gomatrix.MembershipJoin {
|
||||
_, ok := resp.Rooms.Join[roomID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
delete(resp.Rooms.Join, roomID) // don't re-process messages
|
||||
delete(resp.Rooms.Invite, roomID) // don't re-process invites
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// getOrCreateRoom must only be called by the Sync() goroutine which calls ProcessResponse()
|
||||
func (s *MaubotSyncer) getOrCreateRoom(roomID string) *gomatrix.Room {
|
||||
room := s.Store.LoadRoom(roomID)
|
||||
if room == nil {
|
||||
room = gomatrix.NewRoom(roomID)
|
||||
s.Store.SaveRoom(room)
|
||||
}
|
||||
return room
|
||||
}
|
||||
|
||||
func (s *MaubotSyncer) notifyListeners(mxEvent *gomatrix.Event) {
|
||||
event := s.Client.ParseEvent(mxEvent)
|
||||
listeners, exists := s.listeners[event.Type]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
for _, fn := range listeners {
|
||||
if fn(event) == maubot.StopEventPropagation {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OnFailedSync always returns a 10 second wait period between failed /syncs, never a fatal error.
|
||||
func (s *MaubotSyncer) OnFailedSync(res *gomatrix.RespSync, err error) (time.Duration, error) {
|
||||
return 10 * time.Second, nil
|
||||
}
|
||||
|
||||
// GetFilterJSON returns a filter with a timeline limit of 50.
|
||||
func (s *MaubotSyncer) GetFilterJSON(userID string) json.RawMessage {
|
||||
return json.RawMessage(`{"room":{"timeline":{"limit":50}}}`)
|
||||
}
|
3
maubot/__init__.py
Normal file
3
maubot/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .plugin_base import Plugin
|
||||
from .command_spec import CommandSpec, Command, PassiveCommand, Argument
|
||||
from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent
|
78
maubot/__main__.py
Normal file
78
maubot/__main__.py
Normal file
@ -0,0 +1,78 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from sqlalchemy import orm
|
||||
import sqlalchemy as sql
|
||||
import logging.config
|
||||
import argparse
|
||||
import asyncio
|
||||
import copy
|
||||
import sys
|
||||
|
||||
from .config import Config
|
||||
from .db import Base, init as init_db
|
||||
from .server import MaubotServer
|
||||
from .client import Client, init as init_client
|
||||
from .loader import ZippedPluginLoader
|
||||
from .plugin import PluginInstance
|
||||
from .__meta__ import __version__
|
||||
|
||||
parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.",
|
||||
prog="python -m maubot")
|
||||
parser.add_argument("-c", "--config", type=str, default="config.yaml",
|
||||
metavar="<path>", help="the path to your config file")
|
||||
parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml",
|
||||
metavar="<path>", help="the path to the example config "
|
||||
"(for automatic config updates)")
|
||||
args = parser.parse_args()
|
||||
|
||||
config = Config(args.config, args.base_config)
|
||||
config.load()
|
||||
config.update()
|
||||
|
||||
logging.config.dictConfig(copy.deepcopy(config["logging"]))
|
||||
log = logging.getLogger("maubot.init")
|
||||
log.debug(f"Initializing maubot {__version__}")
|
||||
|
||||
db_engine: sql.engine.Engine = sql.create_engine(config["database"])
|
||||
db_factory = orm.sessionmaker(bind=db_engine)
|
||||
db_session = orm.scoping.scoped_session(db_factory)
|
||||
Base.metadata.bind = db_engine
|
||||
Base.metadata.create_all()
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
init_db(db_session)
|
||||
init_client(loop)
|
||||
server = MaubotServer(config, loop)
|
||||
ZippedPluginLoader.load_all(*config["plugin_directories"])
|
||||
plugins = PluginInstance.all()
|
||||
|
||||
for plugin in plugins:
|
||||
plugin.load()
|
||||
|
||||
try:
|
||||
loop.run_until_complete(asyncio.gather(
|
||||
server.start(),
|
||||
*[plugin.start() for plugin in plugins]))
|
||||
log.debug("Startup actions complete, running forever.")
|
||||
loop.run_forever()
|
||||
except KeyboardInterrupt:
|
||||
log.debug("Keyboard interrupt received, stopping...")
|
||||
for client in Client.cache.values():
|
||||
client.stop()
|
||||
db_session.commit()
|
||||
loop.run_until_complete(server.stop())
|
||||
sys.exit(0)
|
1
maubot/__meta__.py
Normal file
1
maubot/__meta__.py
Normal file
@ -0,0 +1 @@
|
||||
__version__ = "0.1.0.dev1"
|
162
maubot/client.py
Normal file
162
maubot/client.py
Normal file
@ -0,0 +1,162 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, List, Optional
|
||||
from aiohttp import ClientSession
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
|
||||
EventType)
|
||||
|
||||
from .db import DBClient
|
||||
from .matrix import MaubotMatrixClient
|
||||
|
||||
log = logging.getLogger("maubot.client")
|
||||
|
||||
|
||||
class Client:
|
||||
loop: asyncio.AbstractEventLoop
|
||||
cache: Dict[UserID, 'Client'] = {}
|
||||
http_client: ClientSession = None
|
||||
|
||||
db_instance: DBClient
|
||||
client: MaubotMatrixClient
|
||||
|
||||
def __init__(self, db_instance: DBClient) -> None:
|
||||
self.db_instance = db_instance
|
||||
self.cache[self.id] = self
|
||||
self.log = log.getChild(self.id)
|
||||
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
|
||||
token=self.access_token, client_session=self.http_client,
|
||||
log=self.log, loop=self.loop, store=self.db_instance)
|
||||
if self.autojoin:
|
||||
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||
|
||||
def start(self) -> None:
|
||||
asyncio.ensure_future(self._start(), loop=self.loop)
|
||||
|
||||
async def _start(self) -> None:
|
||||
try:
|
||||
if self.displayname != "disable":
|
||||
await self.client.set_displayname(self.displayname)
|
||||
if self.avatar_url != "disable":
|
||||
await self.client.set_avatar_url(self.avatar_url)
|
||||
await self.client.start()
|
||||
except Exception:
|
||||
self.log.exception("starting raised exception")
|
||||
|
||||
def stop(self) -> None:
|
||||
self.client.stop()
|
||||
|
||||
@classmethod
|
||||
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
|
||||
try:
|
||||
return cls.cache[user_id]
|
||||
except KeyError:
|
||||
db_instance = db_instance or DBClient.query.get(user_id)
|
||||
if not db_instance:
|
||||
return None
|
||||
return Client(db_instance)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> List['Client']:
|
||||
return [cls.get(user.id, user) for user in DBClient.query.all()]
|
||||
|
||||
async def _handle_invite(self, evt: StrippedStateEvent) -> None:
|
||||
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
|
||||
await self.client.join_room(evt.room_id)
|
||||
|
||||
# region Properties
|
||||
|
||||
@property
|
||||
def id(self) -> UserID:
|
||||
return self.db_instance.id
|
||||
|
||||
@property
|
||||
def homeserver(self) -> str:
|
||||
return self.db_instance.homeserver
|
||||
|
||||
@property
|
||||
def access_token(self) -> str:
|
||||
return self.db_instance.access_token
|
||||
|
||||
@access_token.setter
|
||||
def access_token(self, value: str) -> None:
|
||||
self.client.api.token = value
|
||||
self.db_instance.access_token = value
|
||||
|
||||
@property
|
||||
def next_batch(self) -> SyncToken:
|
||||
return self.db_instance.next_batch
|
||||
|
||||
@next_batch.setter
|
||||
def next_batch(self, value: SyncToken) -> None:
|
||||
self.db_instance.next_batch = value
|
||||
|
||||
@property
|
||||
def filter_id(self) -> FilterID:
|
||||
return self.db_instance.filter_id
|
||||
|
||||
@filter_id.setter
|
||||
def filter_id(self, value: FilterID) -> None:
|
||||
self.db_instance.filter_id = value
|
||||
|
||||
@property
|
||||
def sync(self) -> bool:
|
||||
return self.db_instance.sync
|
||||
|
||||
@sync.setter
|
||||
def sync(self, value: bool) -> None:
|
||||
self.db_instance.sync = value
|
||||
|
||||
@property
|
||||
def autojoin(self) -> bool:
|
||||
return self.db_instance.autojoin
|
||||
|
||||
@autojoin.setter
|
||||
def autojoin(self, value: bool) -> None:
|
||||
if value == self.db_instance.autojoin:
|
||||
return
|
||||
if value:
|
||||
self.client.add_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||
else:
|
||||
self.client.remove_event_handler(self._handle_invite, EventType.ROOM_MEMBER)
|
||||
self.db_instance.autojoin = value
|
||||
|
||||
@property
|
||||
def displayname(self) -> str:
|
||||
return self.db_instance.displayname
|
||||
|
||||
@displayname.setter
|
||||
def displayname(self, value: str) -> None:
|
||||
self.db_instance.displayname = value
|
||||
|
||||
@property
|
||||
def avatar_url(self) -> ContentURI:
|
||||
return self.db_instance.avatar_url
|
||||
|
||||
@avatar_url.setter
|
||||
def avatar_url(self, value: ContentURI) -> None:
|
||||
self.db_instance.avatar_url = value
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def init(loop: asyncio.AbstractEventLoop) -> None:
|
||||
Client.http_client = ClientSession(loop=loop)
|
||||
Client.loop = loop
|
||||
for client in Client.all():
|
||||
client.start()
|
152
maubot/command_spec.py
Normal file
152
maubot/command_spec.py
Normal file
@ -0,0 +1,152 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import List, Dict, Pattern, Union, Tuple, Optional, Any
|
||||
from attr import dataclass
|
||||
import re
|
||||
|
||||
from mautrix.types import MessageEvent, MatchedCommand, MatchedPassiveCommand
|
||||
from mautrix.client.api.types.util import SerializableAttrs
|
||||
|
||||
|
||||
@dataclass
|
||||
class Argument(SerializableAttrs['Argument']):
|
||||
matches: str
|
||||
required: bool = False
|
||||
description: str = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Command(SerializableAttrs['Command']):
|
||||
syntax: str
|
||||
arguments: Dict[str, Argument]
|
||||
description: str = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PassiveCommand(SerializableAttrs['PassiveCommand']):
|
||||
name: str
|
||||
matches: str
|
||||
match_against: str
|
||||
match_event: MessageEvent = None
|
||||
|
||||
|
||||
class ParsedCommand:
|
||||
name: str
|
||||
is_passive: bool
|
||||
arguments: List[str]
|
||||
starts_with: str
|
||||
matches: Pattern
|
||||
match_against: str
|
||||
match_event: MessageEvent
|
||||
|
||||
def __init__(self, command: Union[PassiveCommand, Command]) -> None:
|
||||
if isinstance(command, PassiveCommand):
|
||||
self._init_passive(command)
|
||||
elif isinstance(command, Command):
|
||||
self._init_active(command)
|
||||
else:
|
||||
raise ValueError("Command parameter must be a Command or a PassiveCommand.")
|
||||
|
||||
def _init_passive(self, command: PassiveCommand) -> None:
|
||||
self.name = command.name
|
||||
self.is_passive = True
|
||||
self.match_against = command.match_against
|
||||
self.matches = re.compile(command.matches)
|
||||
self.match_event = command.match_event
|
||||
|
||||
def _init_active(self, command: Command) -> None:
|
||||
self.name = command.syntax
|
||||
self.is_passive = False
|
||||
|
||||
regex_builder = []
|
||||
sw_builder = []
|
||||
argument_encountered = False
|
||||
|
||||
for word in command.syntax.split(" "):
|
||||
arg = command.arguments.get(word, None)
|
||||
if arg is not None and len(word) > 0:
|
||||
argument_encountered = True
|
||||
regex = f"({arg.matches})" if arg.required else f"(?:{arg.matches})?"
|
||||
self.arguments.append(word)
|
||||
regex_builder.append(regex)
|
||||
else:
|
||||
if not argument_encountered:
|
||||
sw_builder.append(word)
|
||||
regex_builder.append(re.escape(word))
|
||||
self.starts_with = "!" + " ".join(sw_builder)
|
||||
self.matches = re.compile("^!" + " ".join(regex_builder) + "$")
|
||||
self.match_against = "body"
|
||||
|
||||
def match(self, evt: MessageEvent) -> bool:
|
||||
return self._match_passive(evt) if self.is_passive else self._match_active(evt)
|
||||
|
||||
@staticmethod
|
||||
def _parse_key(key: str) -> Tuple[str, Optional[str]]:
|
||||
if '.' not in key:
|
||||
return key, None
|
||||
key, next_key = key.split('.', 1)
|
||||
if len(key) > 0 and key[0] == "[":
|
||||
end_index = next_key.index("]")
|
||||
key = key[1:] + "." + next_key[:end_index]
|
||||
next_key = next_key[end_index + 2:] if len(next_key) > end_index + 1 else None
|
||||
return key, next_key
|
||||
|
||||
@classmethod
|
||||
def _recursive_get(cls, data: Any, key: str) -> Any:
|
||||
if not data:
|
||||
return None
|
||||
key, next_key = cls._parse_key(key)
|
||||
if next_key is not None:
|
||||
return cls._recursive_get(data[key], next_key)
|
||||
return data[key]
|
||||
|
||||
def _match_passive(self, evt: MessageEvent) -> bool:
|
||||
try:
|
||||
match_against = self._recursive_get(evt.content, self.match_against)
|
||||
except KeyError:
|
||||
match_against = None
|
||||
match_against = match_against or evt.content.body
|
||||
matches = [[match.string[match.start():match.end()]] + list(match.groups())
|
||||
for match in self.matches.finditer(match_against)]
|
||||
if not matches:
|
||||
return False
|
||||
if evt.unsigned.passive_command is None:
|
||||
evt.unsigned.passive_command = {}
|
||||
evt.unsigned.passive_command[self.name] = MatchedPassiveCommand(captured=matches)
|
||||
return True
|
||||
|
||||
def _match_active(self, evt: MessageEvent) -> bool:
|
||||
if not evt.content.body.startswith(self.starts_with):
|
||||
return False
|
||||
match = self.matches.match(evt.content.body)
|
||||
if not match:
|
||||
return False
|
||||
evt.content.command = MatchedCommand(matched=self.name,
|
||||
arguments=dict(zip(self.arguments, match.groups())))
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandSpec(SerializableAttrs['CommandSpec']):
|
||||
commands: List[Command] = []
|
||||
passive_commands: List[PassiveCommand] = []
|
||||
|
||||
def __add__(self, other: 'CommandSpec') -> 'CommandSpec':
|
||||
return CommandSpec(commands=self.commands + other.commands,
|
||||
passive_commands=self.passive_commands + other.passive_commands)
|
||||
|
||||
def parse(self) -> List[ParsedCommand]:
|
||||
return [ParsedCommand(command) for command in self.commands + self.passive_commands]
|
40
maubot/config.py
Normal file
40
maubot/config.py
Normal file
@ -0,0 +1,40 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
import random
|
||||
import string
|
||||
|
||||
from mautrix.util import BaseConfig
|
||||
|
||||
|
||||
class Config(BaseConfig):
|
||||
@staticmethod
|
||||
def _new_token() -> str:
|
||||
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(64))
|
||||
|
||||
def update(self):
|
||||
base, copy, copy_dict = self._pre_update()
|
||||
copy("database")
|
||||
copy("plugin_directories")
|
||||
copy("server.hostname")
|
||||
copy("server.port")
|
||||
copy("server.listen")
|
||||
copy("server.base_path")
|
||||
shared_secret = self["server.shared_secret"]
|
||||
if shared_secret is None or shared_secret == "generate":
|
||||
base["server.shared_secret"] = self._new_token()
|
||||
else:
|
||||
base["server.shared_secret"] = shared_secret
|
||||
copy("logging")
|
97
maubot/db.py
Normal file
97
maubot/db.py
Normal file
@ -0,0 +1,97 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Type
|
||||
from sqlalchemy import (Column, String, Boolean, ForeignKey, Text, TypeDecorator)
|
||||
from sqlalchemy.orm import Query, scoped_session
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
import json
|
||||
|
||||
from mautrix.types import UserID, FilterID, SyncToken, ContentURI
|
||||
from mautrix.client.api.types.util import Serializable
|
||||
from mautrix import ClientStore
|
||||
|
||||
from .command_spec import CommandSpec
|
||||
|
||||
Base: declarative_base = declarative_base()
|
||||
|
||||
|
||||
def make_serializable_alchemy(serializable_type: Type[Serializable]):
|
||||
class SerializableAlchemy(TypeDecorator):
|
||||
impl = Text
|
||||
|
||||
@property
|
||||
def python_type(self):
|
||||
return serializable_type
|
||||
|
||||
def process_literal_param(self, value: Serializable, _) -> str:
|
||||
return json.dumps(value.serialize()) if value is not None else None
|
||||
|
||||
def process_bind_param(self, value: Serializable, _) -> str:
|
||||
return json.dumps(value.serialize()) if value is not None else None
|
||||
|
||||
def process_result_value(self, value: str, _) -> serializable_type:
|
||||
return serializable_type.deserialize(json.loads(value)) if value is not None else None
|
||||
|
||||
return SerializableAlchemy
|
||||
|
||||
|
||||
class DBPlugin(Base):
|
||||
query: Query
|
||||
__tablename__ = "plugin"
|
||||
|
||||
id: str = Column(String(255), primary_key=True)
|
||||
type: str = Column(String(255), nullable=False)
|
||||
enabled: bool = Column(Boolean, nullable=False, default=False)
|
||||
primary_user: UserID = Column(String(255),
|
||||
ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
|
||||
nullable=False)
|
||||
|
||||
|
||||
class DBClient(Base):
|
||||
query: Query
|
||||
__tablename__ = "client"
|
||||
|
||||
id: UserID = Column(String(255), primary_key=True)
|
||||
homeserver: str = Column(String(255), nullable=False)
|
||||
access_token: str = Column(String(255), nullable=False)
|
||||
|
||||
next_batch: SyncToken = Column(String(255), nullable=False, default="")
|
||||
filter_id: FilterID = Column(String(255), nullable=False, default="")
|
||||
|
||||
sync: bool = Column(Boolean, nullable=False, default=True)
|
||||
autojoin: bool = Column(Boolean, nullable=False, default=True)
|
||||
|
||||
displayname: str = Column(String(255), nullable=False, default="")
|
||||
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
|
||||
|
||||
|
||||
class DBCommandSpec(Base):
|
||||
query: Query
|
||||
__tablename__ = "command_spec"
|
||||
|
||||
plugin: str = Column(String(255),
|
||||
ForeignKey("plugin.id", onupdate="CASCADE", ondelete="CASCADE"),
|
||||
primary_key=True)
|
||||
client: UserID = Column(String(255),
|
||||
ForeignKey("client.id", onupdate="CASCADE", ondelete="CASCADE"),
|
||||
primary_key=True)
|
||||
spec: CommandSpec = Column(make_serializable_alchemy(CommandSpec), nullable=False)
|
||||
|
||||
|
||||
def init(session: scoped_session) -> None:
|
||||
DBPlugin.query = session.query_property()
|
||||
DBClient.query = session.query_property()
|
||||
DBCommandSpec.query = session.query_property()
|
778
maubot/lib/zipimport.py
Normal file
778
maubot/lib/zipimport.py
Normal file
@ -0,0 +1,778 @@
|
||||
# The pure Python implementation of zipimport in Python 3.8+. Slightly modified to allow clearing
|
||||
# the zip directory cache to bypass https://bugs.python.org/issue19081
|
||||
#
|
||||
# https://github.com/python/cpython/blob/5a5ce064b3baadcb79605c5a42ee3d0aee57cdfc/Lib/zipimport.py
|
||||
# See license at https://github.com/python/cpython/blob/master/LICENSE
|
||||
|
||||
"""zipimport provides support for importing Python modules from Zip archives.
|
||||
|
||||
This module exports three objects:
|
||||
- zipimporter: a class; its constructor takes a path to a Zip archive.
|
||||
- ZipImportError: exception raised by zipimporter objects. It's a
|
||||
subclass of ImportError, so it can be caught as ImportError, too.
|
||||
- _zip_directory_cache: a dict, mapping archive paths to zip directory
|
||||
info dicts, as used in zipimporter._files.
|
||||
|
||||
It is usually not needed to use the zipimport module explicitly; it is
|
||||
used by the builtin import mechanism for sys.path items that are paths
|
||||
to Zip archives.
|
||||
"""
|
||||
|
||||
from importlib import _bootstrap_external
|
||||
from importlib import _bootstrap # for _verbose_message
|
||||
import _imp # for check_hash_based_pycs
|
||||
import _io # for open
|
||||
import marshal # for loads
|
||||
import sys # for modules
|
||||
import time # for mktime
|
||||
|
||||
__all__ = ['ZipImportError', 'zipimporter']
|
||||
|
||||
|
||||
def _unpack_uint32(data):
|
||||
"""Convert 4 bytes in little-endian to an integer."""
|
||||
assert len(data) == 4
|
||||
return int.from_bytes(data, 'little')
|
||||
|
||||
def _unpack_uint16(data):
|
||||
"""Convert 2 bytes in little-endian to an integer."""
|
||||
assert len(data) == 2
|
||||
return int.from_bytes(data, 'little')
|
||||
|
||||
|
||||
path_sep = _bootstrap_external.path_sep
|
||||
alt_path_sep = _bootstrap_external.path_separators[1:]
|
||||
|
||||
|
||||
class ZipImportError(ImportError):
|
||||
pass
|
||||
|
||||
# _read_directory() cache
|
||||
_zip_directory_cache = {}
|
||||
|
||||
_module_type = type(sys)
|
||||
|
||||
END_CENTRAL_DIR_SIZE = 22
|
||||
STRING_END_ARCHIVE = b'PK\x05\x06'
|
||||
MAX_COMMENT_LEN = (1 << 16) - 1
|
||||
|
||||
class zipimporter:
|
||||
"""zipimporter(archivepath) -> zipimporter object
|
||||
|
||||
Create a new zipimporter instance. 'archivepath' must be a path to
|
||||
a zipfile, or to a specific path inside a zipfile. For example, it can be
|
||||
'/tmp/myimport.zip', or '/tmp/myimport.zip/mydirectory', if mydirectory is a
|
||||
valid directory inside the archive.
|
||||
|
||||
'ZipImportError is raised if 'archivepath' doesn't point to a valid Zip
|
||||
archive.
|
||||
|
||||
The 'archive' attribute of zipimporter objects contains the name of the
|
||||
zipfile targeted.
|
||||
"""
|
||||
|
||||
# Split the "subdirectory" from the Zip archive path, lookup a matching
|
||||
# entry in sys.path_importer_cache, fetch the file directory from there
|
||||
# if found, or else read it from the archive.
|
||||
def __init__(self, path):
|
||||
if not isinstance(path, str):
|
||||
import os
|
||||
path = os.fsdecode(path)
|
||||
if not path:
|
||||
raise ZipImportError('archive path is empty', path=path)
|
||||
if alt_path_sep:
|
||||
path = path.replace(alt_path_sep, path_sep)
|
||||
|
||||
prefix = []
|
||||
while True:
|
||||
try:
|
||||
st = _bootstrap_external._path_stat(path)
|
||||
except (OSError, ValueError):
|
||||
# On Windows a ValueError is raised for too long paths.
|
||||
# Back up one path element.
|
||||
dirname, basename = _bootstrap_external._path_split(path)
|
||||
if dirname == path:
|
||||
raise ZipImportError('not a Zip file', path=path)
|
||||
path = dirname
|
||||
prefix.append(basename)
|
||||
else:
|
||||
# it exists
|
||||
if (st.st_mode & 0o170000) != 0o100000: # stat.S_ISREG
|
||||
# it's a not file
|
||||
raise ZipImportError('not a Zip file', path=path)
|
||||
break
|
||||
|
||||
try:
|
||||
files = _zip_directory_cache[path]
|
||||
except KeyError:
|
||||
files = _read_directory(path)
|
||||
_zip_directory_cache[path] = files
|
||||
self._files = files
|
||||
self.archive = path
|
||||
# a prefix directory following the ZIP file path.
|
||||
self.prefix = _bootstrap_external._path_join(*prefix[::-1])
|
||||
if self.prefix:
|
||||
self.prefix += path_sep
|
||||
|
||||
def reset_cache(self):
|
||||
self._files = _read_directory(self.archive)
|
||||
_zip_directory_cache[self.archive] = self._files
|
||||
|
||||
def remove_cache(self):
|
||||
try:
|
||||
del _zip_directory_cache[self.archive]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Check whether we can satisfy the import of the module named by
|
||||
# 'fullname', or whether it could be a portion of a namespace
|
||||
# package. Return self if we can load it, a string containing the
|
||||
# full path if it's a possible namespace portion, None if we
|
||||
# can't load it.
|
||||
def find_loader(self, fullname, path=None):
|
||||
"""find_loader(fullname, path=None) -> self, str or None.
|
||||
|
||||
Search for a module specified by 'fullname'. 'fullname' must be the
|
||||
fully qualified (dotted) module name. It returns the zipimporter
|
||||
instance itself if the module was found, a string containing the
|
||||
full path name if it's possibly a portion of a namespace package,
|
||||
or None otherwise. The optional 'path' argument is ignored -- it's
|
||||
there for compatibility with the importer protocol.
|
||||
"""
|
||||
mi = _get_module_info(self, fullname)
|
||||
if mi is not None:
|
||||
# This is a module or package.
|
||||
return self, []
|
||||
|
||||
# Not a module or regular package. See if this is a directory, and
|
||||
# therefore possibly a portion of a namespace package.
|
||||
|
||||
# We're only interested in the last path component of fullname
|
||||
# earlier components are recorded in self.prefix.
|
||||
modpath = _get_module_path(self, fullname)
|
||||
if _is_dir(self, modpath):
|
||||
# This is possibly a portion of a namespace
|
||||
# package. Return the string representing its path,
|
||||
# without a trailing separator.
|
||||
return None, [f'{self.archive}{path_sep}{modpath}']
|
||||
|
||||
return None, []
|
||||
|
||||
|
||||
# Check whether we can satisfy the import of the module named by
|
||||
# 'fullname'. Return self if we can, None if we can't.
|
||||
def find_module(self, fullname, path=None):
|
||||
"""find_module(fullname, path=None) -> self or None.
|
||||
|
||||
Search for a module specified by 'fullname'. 'fullname' must be the
|
||||
fully qualified (dotted) module name. It returns the zipimporter
|
||||
instance itself if the module was found, or None if it wasn't.
|
||||
The optional 'path' argument is ignored -- it's there for compatibility
|
||||
with the importer protocol.
|
||||
"""
|
||||
return self.find_loader(fullname, path)[0]
|
||||
|
||||
|
||||
def get_code(self, fullname):
|
||||
"""get_code(fullname) -> code object.
|
||||
|
||||
Return the code object for the specified module. Raise ZipImportError
|
||||
if the module couldn't be found.
|
||||
"""
|
||||
code, ispackage, modpath = _get_module_code(self, fullname)
|
||||
return code
|
||||
|
||||
|
||||
def get_data(self, pathname):
|
||||
"""get_data(pathname) -> string with file data.
|
||||
|
||||
Return the data associated with 'pathname'. Raise OSError if
|
||||
the file wasn't found.
|
||||
"""
|
||||
if alt_path_sep:
|
||||
pathname = pathname.replace(alt_path_sep, path_sep)
|
||||
|
||||
key = pathname
|
||||
if pathname.startswith(self.archive + path_sep):
|
||||
key = pathname[len(self.archive + path_sep):]
|
||||
|
||||
try:
|
||||
toc_entry = self._files[key]
|
||||
except KeyError:
|
||||
raise OSError(0, '', key)
|
||||
return _get_data(self.archive, toc_entry)
|
||||
|
||||
|
||||
# Return a string matching __file__ for the named module
|
||||
def get_filename(self, fullname):
|
||||
"""get_filename(fullname) -> filename string.
|
||||
|
||||
Return the filename for the specified module.
|
||||
"""
|
||||
# Deciding the filename requires working out where the code
|
||||
# would come from if the module was actually loaded
|
||||
code, ispackage, modpath = _get_module_code(self, fullname)
|
||||
return modpath
|
||||
|
||||
|
||||
def get_source(self, fullname):
|
||||
"""get_source(fullname) -> source string.
|
||||
|
||||
Return the source code for the specified module. Raise ZipImportError
|
||||
if the module couldn't be found, return None if the archive does
|
||||
contain the module, but has no source for it.
|
||||
"""
|
||||
mi = _get_module_info(self, fullname)
|
||||
if mi is None:
|
||||
raise ZipImportError(f"can't find module {fullname!r}", name=fullname)
|
||||
|
||||
path = _get_module_path(self, fullname)
|
||||
if mi:
|
||||
fullpath = _bootstrap_external._path_join(path, '__init__.py')
|
||||
else:
|
||||
fullpath = f'{path}.py'
|
||||
|
||||
try:
|
||||
toc_entry = self._files[fullpath]
|
||||
except KeyError:
|
||||
# we have the module, but no source
|
||||
return None
|
||||
return _get_data(self.archive, toc_entry).decode()
|
||||
|
||||
|
||||
# Return a bool signifying whether the module is a package or not.
|
||||
def is_package(self, fullname):
|
||||
"""is_package(fullname) -> bool.
|
||||
|
||||
Return True if the module specified by fullname is a package.
|
||||
Raise ZipImportError if the module couldn't be found.
|
||||
"""
|
||||
mi = _get_module_info(self, fullname)
|
||||
if mi is None:
|
||||
raise ZipImportError(f"can't find module {fullname!r}", name=fullname)
|
||||
return mi
|
||||
|
||||
|
||||
# Load and return the module named by 'fullname'.
|
||||
def load_module(self, fullname):
|
||||
"""load_module(fullname) -> module.
|
||||
|
||||
Load the module specified by 'fullname'. 'fullname' must be the
|
||||
fully qualified (dotted) module name. It returns the imported
|
||||
module, or raises ZipImportError if it wasn't found.
|
||||
"""
|
||||
code, ispackage, modpath = _get_module_code(self, fullname)
|
||||
mod = sys.modules.get(fullname)
|
||||
if mod is None or not isinstance(mod, _module_type):
|
||||
mod = _module_type(fullname)
|
||||
sys.modules[fullname] = mod
|
||||
mod.__loader__ = self
|
||||
|
||||
try:
|
||||
if ispackage:
|
||||
# add __path__ to the module *before* the code gets
|
||||
# executed
|
||||
path = _get_module_path(self, fullname)
|
||||
fullpath = _bootstrap_external._path_join(self.archive, path)
|
||||
mod.__path__ = [fullpath]
|
||||
|
||||
if not hasattr(mod, '__builtins__'):
|
||||
mod.__builtins__ = __builtins__
|
||||
_bootstrap_external._fix_up_module(mod.__dict__, fullname, modpath)
|
||||
exec(code, mod.__dict__)
|
||||
except:
|
||||
del sys.modules[fullname]
|
||||
raise
|
||||
|
||||
try:
|
||||
mod = sys.modules[fullname]
|
||||
except KeyError:
|
||||
raise ImportError(f'Loaded module {fullname!r} not found in sys.modules')
|
||||
_bootstrap._verbose_message('import {} # loaded from Zip {}', fullname, modpath)
|
||||
return mod
|
||||
|
||||
|
||||
def get_resource_reader(self, fullname):
|
||||
"""Return the ResourceReader for a package in a zip file.
|
||||
|
||||
If 'fullname' is a package within the zip file, return the
|
||||
'ResourceReader' object for the package. Otherwise return None.
|
||||
"""
|
||||
try:
|
||||
if not self.is_package(fullname):
|
||||
return None
|
||||
except ZipImportError:
|
||||
return None
|
||||
if not _ZipImportResourceReader._registered:
|
||||
from importlib.abc import ResourceReader
|
||||
ResourceReader.register(_ZipImportResourceReader)
|
||||
_ZipImportResourceReader._registered = True
|
||||
return _ZipImportResourceReader(self, fullname)
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return f'<zipimporter object "{self.archive}{path_sep}{self.prefix}">'
|
||||
|
||||
|
||||
# _zip_searchorder defines how we search for a module in the Zip
|
||||
# archive: we first search for a package __init__, then for
|
||||
# non-package .pyc, and .py entries. The .pyc entries
|
||||
# are swapped by initzipimport() if we run in optimized mode. Also,
|
||||
# '/' is replaced by path_sep there.
|
||||
_zip_searchorder = (
|
||||
(path_sep + '__init__.pyc', True, True),
|
||||
(path_sep + '__init__.py', False, True),
|
||||
('.pyc', True, False),
|
||||
('.py', False, False),
|
||||
)
|
||||
|
||||
# Given a module name, return the potential file path in the
|
||||
# archive (without extension).
|
||||
def _get_module_path(self, fullname):
|
||||
return self.prefix + fullname.rpartition('.')[2]
|
||||
|
||||
# Does this path represent a directory?
|
||||
def _is_dir(self, path):
|
||||
# See if this is a "directory". If so, it's eligible to be part
|
||||
# of a namespace package. We test by seeing if the name, with an
|
||||
# appended path separator, exists.
|
||||
dirpath = path + path_sep
|
||||
# If dirpath is present in self._files, we have a directory.
|
||||
return dirpath in self._files
|
||||
|
||||
# Return some information about a module.
|
||||
def _get_module_info(self, fullname):
|
||||
path = _get_module_path(self, fullname)
|
||||
for suffix, isbytecode, ispackage in _zip_searchorder:
|
||||
fullpath = path + suffix
|
||||
if fullpath in self._files:
|
||||
return ispackage
|
||||
return None
|
||||
|
||||
|
||||
# implementation
|
||||
|
||||
# _read_directory(archive) -> files dict (new reference)
|
||||
#
|
||||
# Given a path to a Zip archive, build a dict, mapping file names
|
||||
# (local to the archive, using SEP as a separator) to toc entries.
|
||||
#
|
||||
# A toc_entry is a tuple:
|
||||
#
|
||||
# (__file__, # value to use for __file__, available for all files,
|
||||
# # encoded to the filesystem encoding
|
||||
# compress, # compression kind; 0 for uncompressed
|
||||
# data_size, # size of compressed data on disk
|
||||
# file_size, # size of decompressed data
|
||||
# file_offset, # offset of file header from start of archive
|
||||
# time, # mod time of file (in dos format)
|
||||
# date, # mod data of file (in dos format)
|
||||
# crc, # crc checksum of the data
|
||||
# )
|
||||
#
|
||||
# Directories can be recognized by the trailing path_sep in the name,
|
||||
# data_size and file_offset are 0.
|
||||
def _read_directory(archive):
|
||||
try:
|
||||
fp = _io.open(archive, 'rb')
|
||||
except OSError:
|
||||
raise ZipImportError(f"can't open Zip file: {archive!r}", path=archive)
|
||||
|
||||
with fp:
|
||||
try:
|
||||
fp.seek(-END_CENTRAL_DIR_SIZE, 2)
|
||||
header_position = fp.tell()
|
||||
buffer = fp.read(END_CENTRAL_DIR_SIZE)
|
||||
except OSError:
|
||||
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
|
||||
if len(buffer) != END_CENTRAL_DIR_SIZE:
|
||||
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
|
||||
if buffer[:4] != STRING_END_ARCHIVE:
|
||||
# Bad: End of Central Dir signature
|
||||
# Check if there's a comment.
|
||||
try:
|
||||
fp.seek(0, 2)
|
||||
file_size = fp.tell()
|
||||
except OSError:
|
||||
raise ZipImportError(f"can't read Zip file: {archive!r}",
|
||||
path=archive)
|
||||
max_comment_start = max(file_size - MAX_COMMENT_LEN -
|
||||
END_CENTRAL_DIR_SIZE, 0)
|
||||
try:
|
||||
fp.seek(max_comment_start)
|
||||
data = fp.read()
|
||||
except OSError:
|
||||
raise ZipImportError(f"can't read Zip file: {archive!r}",
|
||||
path=archive)
|
||||
pos = data.rfind(STRING_END_ARCHIVE)
|
||||
if pos < 0:
|
||||
raise ZipImportError(f'not a Zip file: {archive!r}',
|
||||
path=archive)
|
||||
buffer = data[pos:pos+END_CENTRAL_DIR_SIZE]
|
||||
if len(buffer) != END_CENTRAL_DIR_SIZE:
|
||||
raise ZipImportError(f"corrupt Zip file: {archive!r}",
|
||||
path=archive)
|
||||
header_position = file_size - len(data) + pos
|
||||
|
||||
header_size = _unpack_uint32(buffer[12:16])
|
||||
header_offset = _unpack_uint32(buffer[16:20])
|
||||
if header_position < header_size:
|
||||
raise ZipImportError(f'bad central directory size: {archive!r}', path=archive)
|
||||
if header_position < header_offset:
|
||||
raise ZipImportError(f'bad central directory offset: {archive!r}', path=archive)
|
||||
header_position -= header_size
|
||||
arc_offset = header_position - header_offset
|
||||
if arc_offset < 0:
|
||||
raise ZipImportError(f'bad central directory size or offset: {archive!r}', path=archive)
|
||||
|
||||
files = {}
|
||||
# Start of Central Directory
|
||||
count = 0
|
||||
try:
|
||||
fp.seek(header_position)
|
||||
except OSError:
|
||||
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
|
||||
while True:
|
||||
buffer = fp.read(46)
|
||||
if len(buffer) < 4:
|
||||
raise EOFError('EOF read where not expected')
|
||||
# Start of file header
|
||||
if buffer[:4] != b'PK\x01\x02':
|
||||
break # Bad: Central Dir File Header
|
||||
if len(buffer) != 46:
|
||||
raise EOFError('EOF read where not expected')
|
||||
flags = _unpack_uint16(buffer[8:10])
|
||||
compress = _unpack_uint16(buffer[10:12])
|
||||
time = _unpack_uint16(buffer[12:14])
|
||||
date = _unpack_uint16(buffer[14:16])
|
||||
crc = _unpack_uint32(buffer[16:20])
|
||||
data_size = _unpack_uint32(buffer[20:24])
|
||||
file_size = _unpack_uint32(buffer[24:28])
|
||||
name_size = _unpack_uint16(buffer[28:30])
|
||||
extra_size = _unpack_uint16(buffer[30:32])
|
||||
comment_size = _unpack_uint16(buffer[32:34])
|
||||
file_offset = _unpack_uint32(buffer[42:46])
|
||||
header_size = name_size + extra_size + comment_size
|
||||
if file_offset > header_offset:
|
||||
raise ZipImportError(f'bad local header offset: {archive!r}', path=archive)
|
||||
file_offset += arc_offset
|
||||
|
||||
try:
|
||||
name = fp.read(name_size)
|
||||
except OSError:
|
||||
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
|
||||
if len(name) != name_size:
|
||||
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
|
||||
# On Windows, calling fseek to skip over the fields we don't use is
|
||||
# slower than reading the data because fseek flushes stdio's
|
||||
# internal buffers. See issue #8745.
|
||||
try:
|
||||
if len(fp.read(header_size - name_size)) != header_size - name_size:
|
||||
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
|
||||
except OSError:
|
||||
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
|
||||
|
||||
if flags & 0x800:
|
||||
# UTF-8 file names extension
|
||||
name = name.decode()
|
||||
else:
|
||||
# Historical ZIP filename encoding
|
||||
try:
|
||||
name = name.decode('ascii')
|
||||
except UnicodeDecodeError:
|
||||
name = name.decode('latin1').translate(cp437_table)
|
||||
|
||||
name = name.replace('/', path_sep)
|
||||
path = _bootstrap_external._path_join(archive, name)
|
||||
t = (path, compress, data_size, file_size, file_offset, time, date, crc)
|
||||
files[name] = t
|
||||
count += 1
|
||||
_bootstrap._verbose_message('zipimport: found {} names in {!r}', count, archive)
|
||||
return files
|
||||
|
||||
# During bootstrap, we may need to load the encodings
|
||||
# package from a ZIP file. But the cp437 encoding is implemented
|
||||
# in Python in the encodings package.
|
||||
#
|
||||
# Break out of this dependency by using the translation table for
|
||||
# the cp437 encoding.
|
||||
cp437_table = (
|
||||
# ASCII part, 8 rows x 16 chars
|
||||
'\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f'
|
||||
'\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f'
|
||||
' !"#$%&\'()*+,-./'
|
||||
'0123456789:;<=>?'
|
||||
'@ABCDEFGHIJKLMNO'
|
||||
'PQRSTUVWXYZ[\\]^_'
|
||||
'`abcdefghijklmno'
|
||||
'pqrstuvwxyz{|}~\x7f'
|
||||
# non-ASCII part, 16 rows x 8 chars
|
||||
'\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7'
|
||||
'\xea\xeb\xe8\xef\xee\xec\xc4\xc5'
|
||||
'\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9'
|
||||
'\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192'
|
||||
'\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba'
|
||||
'\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb'
|
||||
'\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556'
|
||||
'\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510'
|
||||
'\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f'
|
||||
'\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567'
|
||||
'\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b'
|
||||
'\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580'
|
||||
'\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4'
|
||||
'\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229'
|
||||
'\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248'
|
||||
'\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0'
|
||||
)
|
||||
|
||||
_importing_zlib = False
|
||||
|
||||
# Return the zlib.decompress function object, or NULL if zlib couldn't
|
||||
# be imported. The function is cached when found, so subsequent calls
|
||||
# don't import zlib again.
|
||||
def _get_decompress_func():
|
||||
global _importing_zlib
|
||||
if _importing_zlib:
|
||||
# Someone has a zlib.py[co] in their Zip file
|
||||
# let's avoid a stack overflow.
|
||||
_bootstrap._verbose_message('zipimport: zlib UNAVAILABLE')
|
||||
raise ZipImportError("can't decompress data; zlib not available")
|
||||
|
||||
_importing_zlib = True
|
||||
try:
|
||||
from zlib import decompress
|
||||
except Exception:
|
||||
_bootstrap._verbose_message('zipimport: zlib UNAVAILABLE')
|
||||
raise ZipImportError("can't decompress data; zlib not available")
|
||||
finally:
|
||||
_importing_zlib = False
|
||||
|
||||
_bootstrap._verbose_message('zipimport: zlib available')
|
||||
return decompress
|
||||
|
||||
# Given a path to a Zip file and a toc_entry, return the (uncompressed) data.
|
||||
def _get_data(archive, toc_entry):
|
||||
datapath, compress, data_size, file_size, file_offset, time, date, crc = toc_entry
|
||||
if data_size < 0:
|
||||
raise ZipImportError('negative data size')
|
||||
|
||||
with _io.open(archive, 'rb') as fp:
|
||||
# Check to make sure the local file header is correct
|
||||
try:
|
||||
fp.seek(file_offset)
|
||||
except OSError:
|
||||
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
|
||||
buffer = fp.read(30)
|
||||
if len(buffer) != 30:
|
||||
raise EOFError('EOF read where not expected')
|
||||
|
||||
if buffer[:4] != b'PK\x03\x04':
|
||||
# Bad: Local File Header
|
||||
raise ZipImportError(f'bad local file header: {archive!r}', path=archive)
|
||||
|
||||
name_size = _unpack_uint16(buffer[26:28])
|
||||
extra_size = _unpack_uint16(buffer[28:30])
|
||||
header_size = 30 + name_size + extra_size
|
||||
file_offset += header_size # Start of file data
|
||||
try:
|
||||
fp.seek(file_offset)
|
||||
except OSError:
|
||||
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
|
||||
raw_data = fp.read(data_size)
|
||||
if len(raw_data) != data_size:
|
||||
raise OSError("zipimport: can't read data")
|
||||
|
||||
if compress == 0:
|
||||
# data is not compressed
|
||||
return raw_data
|
||||
|
||||
# Decompress with zlib
|
||||
try:
|
||||
decompress = _get_decompress_func()
|
||||
except Exception:
|
||||
raise ZipImportError("can't decompress data; zlib not available")
|
||||
return decompress(raw_data, -15)
|
||||
|
||||
|
||||
# Lenient date/time comparison function. The precision of the mtime
|
||||
# in the archive is lower than the mtime stored in a .pyc: we
|
||||
# must allow a difference of at most one second.
|
||||
def _eq_mtime(t1, t2):
|
||||
# dostime only stores even seconds, so be lenient
|
||||
return abs(t1 - t2) <= 1
|
||||
|
||||
# Given the contents of a .py[co] file, unmarshal the data
|
||||
# and return the code object. Return None if it the magic word doesn't
|
||||
# match (we do this instead of raising an exception as we fall back
|
||||
# to .py if available and we don't want to mask other errors).
|
||||
def _unmarshal_code(pathname, data, mtime):
|
||||
if len(data) < 16:
|
||||
raise ZipImportError('bad pyc data')
|
||||
|
||||
if data[:4] != _bootstrap_external.MAGIC_NUMBER:
|
||||
_bootstrap._verbose_message('{!r} has bad magic', pathname)
|
||||
return None # signal caller to try alternative
|
||||
|
||||
flags = _unpack_uint32(data[4:8])
|
||||
if flags != 0:
|
||||
# Hash-based pyc. We currently refuse to handle checked hash-based
|
||||
# pycs. We could validate hash-based pycs against the source, but it
|
||||
# seems likely that most people putting hash-based pycs in a zipfile
|
||||
# will use unchecked ones.
|
||||
if (_imp.check_hash_based_pycs != 'never' and
|
||||
(flags != 0x1 or _imp.check_hash_based_pycs == 'always')):
|
||||
return None
|
||||
elif mtime != 0 and not _eq_mtime(_unpack_uint32(data[8:12]), mtime):
|
||||
_bootstrap._verbose_message('{!r} has bad mtime', pathname)
|
||||
return None # signal caller to try alternative
|
||||
|
||||
# XXX the pyc's size field is ignored; timestamp collisions are probably
|
||||
# unimportant with zip files.
|
||||
code = marshal.loads(data[16:])
|
||||
if not isinstance(code, _code_type):
|
||||
raise TypeError(f'compiled module {pathname!r} is not a code object')
|
||||
return code
|
||||
|
||||
_code_type = type(_unmarshal_code.__code__)
|
||||
|
||||
|
||||
# Replace any occurrences of '\r\n?' in the input string with '\n'.
|
||||
# This converts DOS and Mac line endings to Unix line endings.
|
||||
def _normalize_line_endings(source):
|
||||
source = source.replace(b'\r\n', b'\n')
|
||||
source = source.replace(b'\r', b'\n')
|
||||
return source
|
||||
|
||||
# Given a string buffer containing Python source code, compile it
|
||||
# and return a code object.
|
||||
def _compile_source(pathname, source):
|
||||
source = _normalize_line_endings(source)
|
||||
return compile(source, pathname, 'exec', dont_inherit=True)
|
||||
|
||||
# Convert the date/time values found in the Zip archive to a value
|
||||
# that's compatible with the time stamp stored in .pyc files.
|
||||
def _parse_dostime(d, t):
|
||||
return time.mktime((
|
||||
(d >> 9) + 1980, # bits 9..15: year
|
||||
(d >> 5) & 0xF, # bits 5..8: month
|
||||
d & 0x1F, # bits 0..4: day
|
||||
t >> 11, # bits 11..15: hours
|
||||
(t >> 5) & 0x3F, # bits 8..10: minutes
|
||||
(t & 0x1F) * 2, # bits 0..7: seconds / 2
|
||||
-1, -1, -1))
|
||||
|
||||
# Given a path to a .pyc file in the archive, return the
|
||||
# modification time of the matching .py file, or 0 if no source
|
||||
# is available.
|
||||
def _get_mtime_of_source(self, path):
|
||||
try:
|
||||
# strip 'c' or 'o' from *.py[co]
|
||||
assert path[-1:] in ('c', 'o')
|
||||
path = path[:-1]
|
||||
toc_entry = self._files[path]
|
||||
# fetch the time stamp of the .py file for comparison
|
||||
# with an embedded pyc time stamp
|
||||
time = toc_entry[5]
|
||||
date = toc_entry[6]
|
||||
return _parse_dostime(date, time)
|
||||
except (KeyError, IndexError, TypeError):
|
||||
return 0
|
||||
|
||||
# Get the code object associated with the module specified by
|
||||
# 'fullname'.
|
||||
def _get_module_code(self, fullname):
|
||||
path = _get_module_path(self, fullname)
|
||||
for suffix, isbytecode, ispackage in _zip_searchorder:
|
||||
fullpath = path + suffix
|
||||
_bootstrap._verbose_message('trying {}{}{}', self.archive, path_sep, fullpath, verbosity=2)
|
||||
try:
|
||||
toc_entry = self._files[fullpath]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
modpath = toc_entry[0]
|
||||
data = _get_data(self.archive, toc_entry)
|
||||
if isbytecode:
|
||||
mtime = _get_mtime_of_source(self, fullpath)
|
||||
code = _unmarshal_code(modpath, data, mtime)
|
||||
else:
|
||||
code = _compile_source(modpath, data)
|
||||
if code is None:
|
||||
# bad magic number or non-matching mtime
|
||||
# in byte code, try next
|
||||
continue
|
||||
modpath = toc_entry[0]
|
||||
return code, ispackage, modpath
|
||||
else:
|
||||
raise ZipImportError(f"can't find module {fullname!r}", name=fullname)
|
||||
|
||||
|
||||
class _ZipImportResourceReader:
|
||||
"""Private class used to support ZipImport.get_resource_reader().
|
||||
|
||||
This class is allowed to reference all the innards and private parts of
|
||||
the zipimporter.
|
||||
"""
|
||||
_registered = False
|
||||
|
||||
def __init__(self, zipimporter, fullname):
|
||||
self.zipimporter = zipimporter
|
||||
self.fullname = fullname
|
||||
|
||||
def open_resource(self, resource):
|
||||
fullname_as_path = self.fullname.replace('.', '/')
|
||||
path = f'{fullname_as_path}/{resource}'
|
||||
from io import BytesIO
|
||||
try:
|
||||
return BytesIO(self.zipimporter.get_data(path))
|
||||
except OSError:
|
||||
raise FileNotFoundError(path)
|
||||
|
||||
def resource_path(self, resource):
|
||||
# All resources are in the zip file, so there is no path to the file.
|
||||
# Raising FileNotFoundError tells the higher level API to extract the
|
||||
# binary data and create a temporary file.
|
||||
raise FileNotFoundError
|
||||
|
||||
def is_resource(self, name):
|
||||
# Maybe we could do better, but if we can get the data, it's a
|
||||
# resource. Otherwise it isn't.
|
||||
fullname_as_path = self.fullname.replace('.', '/')
|
||||
path = f'{fullname_as_path}/{name}'
|
||||
try:
|
||||
self.zipimporter.get_data(path)
|
||||
except OSError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def contents(self):
|
||||
# This is a bit convoluted, because fullname will be a module path,
|
||||
# but _files is a list of file names relative to the top of the
|
||||
# archive's namespace. We want to compare file paths to find all the
|
||||
# names of things inside the module represented by fullname. So we
|
||||
# turn the module path of fullname into a file path relative to the
|
||||
# top of the archive, and then we iterate through _files looking for
|
||||
# names inside that "directory".
|
||||
from pathlib import Path
|
||||
fullname_path = Path(self.zipimporter.get_filename(self.fullname))
|
||||
relative_path = fullname_path.relative_to(self.zipimporter.archive)
|
||||
# Don't forget that fullname names a package, so its path will include
|
||||
# __init__.py, which we want to ignore.
|
||||
assert relative_path.name == '__init__.py'
|
||||
package_path = relative_path.parent
|
||||
subdirs_seen = set()
|
||||
for filename in self.zipimporter._files:
|
||||
try:
|
||||
relative = Path(filename).relative_to(package_path)
|
||||
except ValueError:
|
||||
continue
|
||||
# If the path of the file (which is relative to the top of the zip
|
||||
# namespace), relative to the package given when the resource
|
||||
# reader was created, has a parent, then it's a name in a
|
||||
# subdirectory and thus we skip it.
|
||||
parent_name = relative.parent.name
|
||||
if len(parent_name) == 0:
|
||||
yield relative.name
|
||||
elif parent_name not in subdirs_seen:
|
||||
subdirs_seen.add(parent_name)
|
||||
yield parent_name
|
2
maubot/loader/__init__.py
Normal file
2
maubot/loader/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .abc import PluginLoader, PluginClass
|
||||
from .zip import ZippedPluginLoader, MaubotZipImportError
|
60
maubot/loader/abc.py
Normal file
60
maubot/loader/abc.py
Normal file
@ -0,0 +1,60 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import TypeVar, Type, Dict, Set, TYPE_CHECKING
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from ..plugin_base import Plugin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..plugin import PluginInstance
|
||||
|
||||
PluginClass = TypeVar("PluginClass", bound=Plugin)
|
||||
|
||||
|
||||
class IDConflictError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PluginLoader(ABC):
|
||||
id_cache: Dict[str, 'PluginLoader'] = {}
|
||||
|
||||
references: Set['PluginInstance']
|
||||
id: str
|
||||
version: str
|
||||
|
||||
def __init__(self):
|
||||
self.references = set()
|
||||
|
||||
@classmethod
|
||||
def find(cls, plugin_id: str) -> 'PluginLoader':
|
||||
return cls.id_cache[plugin_id]
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def source(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> Type[PluginClass]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reload(self) -> Type[PluginClass]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unload(self) -> None:
|
||||
pass
|
190
maubot/loader/zip.py
Normal file
190
maubot/loader/zip.py
Normal file
@ -0,0 +1,190 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, List, Type
|
||||
from zipfile import ZipFile, BadZipFile
|
||||
import configparser
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
|
||||
from ..lib.zipimport import zipimporter, ZipImportError
|
||||
from ..plugin_base import Plugin
|
||||
from .abc import PluginLoader, PluginClass, IDConflictError
|
||||
|
||||
|
||||
class MaubotZipImportError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ZippedPluginLoader(PluginLoader):
|
||||
path_cache: Dict[str, 'ZippedPluginLoader'] = {}
|
||||
log = logging.getLogger("maubot.loader.zip")
|
||||
|
||||
path: str
|
||||
id: str
|
||||
version: str
|
||||
modules: List[str]
|
||||
main_class: str
|
||||
main_module: str
|
||||
_loaded: Type[PluginClass]
|
||||
_importer: zipimporter
|
||||
|
||||
def __init__(self, path: str) -> None:
|
||||
super().__init__()
|
||||
self.path = path
|
||||
self.id = None
|
||||
self._loaded = None
|
||||
self._importer = None
|
||||
self._load_meta()
|
||||
self._run_preload_checks(self._get_importer())
|
||||
try:
|
||||
existing = self.id_cache[self.id]
|
||||
raise IDConflictError(f"Plugin with id {self.id} already loaded from {existing.source}")
|
||||
except KeyError:
|
||||
pass
|
||||
self.path_cache[self.path] = self
|
||||
self.id_cache[self.id] = self
|
||||
self.log.debug(f"Preloaded plugin {self.id} from {self.path}")
|
||||
|
||||
@classmethod
|
||||
def get(cls, path: str) -> 'ZippedPluginLoader':
|
||||
try:
|
||||
return cls.path_cache[path]
|
||||
except KeyError:
|
||||
return cls(path)
|
||||
|
||||
@property
|
||||
def source(self) -> str:
|
||||
return self.path
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return ("<ZippedPlugin "
|
||||
f"path='{self.path}' "
|
||||
f"id='{self.id}' "
|
||||
f"loaded={self._loaded is not None}>")
|
||||
|
||||
def _load_meta(self) -> None:
|
||||
try:
|
||||
file = ZipFile(self.path)
|
||||
data = file.read("maubot.ini")
|
||||
except FileNotFoundError as e:
|
||||
raise MaubotZipImportError("Maubot plugin not found") from e
|
||||
except BadZipFile as e:
|
||||
raise MaubotZipImportError("File is not a maubot plugin") from e
|
||||
except KeyError as e:
|
||||
raise MaubotZipImportError("File does not contain a maubot plugin definition") from e
|
||||
config = configparser.ConfigParser()
|
||||
try:
|
||||
config.read_string(data.decode("utf-8"), source=f"{self.path}/maubot.ini")
|
||||
meta = config["maubot"]
|
||||
meta_id = meta["ID"]
|
||||
version = meta["Version"]
|
||||
modules = [mod.strip() for mod in meta["Modules"].split(",")]
|
||||
main_class = meta["MainClass"]
|
||||
main_module = modules[-1]
|
||||
if "/" in main_class:
|
||||
main_module, main_class = main_class.split("/")[:2]
|
||||
except (configparser.Error, KeyError, IndexError, ValueError) as e:
|
||||
raise MaubotZipImportError("Maubot plugin definition in file is invalid") from e
|
||||
if self.id and meta_id != self.id:
|
||||
raise MaubotZipImportError("Maubot plugin ID changed during reload")
|
||||
self.id, self.version, self.modules = meta_id, version, modules
|
||||
self.main_class, self.main_module = main_class, main_module
|
||||
|
||||
def _get_importer(self, reset_cache: bool = False) -> zipimporter:
|
||||
try:
|
||||
if not self._importer:
|
||||
self._importer = zipimporter(self.path)
|
||||
if reset_cache:
|
||||
self._importer.reset_cache()
|
||||
return self._importer
|
||||
except ZipImportError as e:
|
||||
raise MaubotZipImportError("File not found or not a maubot plugin") from e
|
||||
|
||||
def _run_preload_checks(self, importer: zipimporter) -> None:
|
||||
try:
|
||||
code = importer.get_code(self.main_module.replace(".", "/"))
|
||||
if self.main_class not in code.co_names:
|
||||
raise MaubotZipImportError(
|
||||
f"Main class {self.main_class} not in {self.main_module}")
|
||||
except ZipImportError as e:
|
||||
raise MaubotZipImportError(
|
||||
f"Main module {self.main_module} not found in file") from e
|
||||
for module in self.modules:
|
||||
try:
|
||||
importer.find_module(module)
|
||||
except ZipImportError as e:
|
||||
raise MaubotZipImportError(f"Module {module} not found in file") from e
|
||||
|
||||
def load(self, reset_cache: bool = False) -> Type[PluginClass]:
|
||||
if self._loaded is not None and not reset_cache:
|
||||
return self._loaded
|
||||
importer = self._get_importer(reset_cache=reset_cache)
|
||||
self._run_preload_checks(importer)
|
||||
if reset_cache:
|
||||
self.log.debug(f"Preloaded plugin {self.id} from {self.path}")
|
||||
for module in self.modules:
|
||||
importer.load_module(module)
|
||||
main_mod = sys.modules[self.main_module]
|
||||
plugin = getattr(main_mod, self.main_class)
|
||||
if not issubclass(plugin, Plugin):
|
||||
raise MaubotZipImportError("Main class of plugin does not extend maubot.Plugin")
|
||||
self._loaded = plugin
|
||||
self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}")
|
||||
return plugin
|
||||
|
||||
def reload(self) -> Type[PluginClass]:
|
||||
self.unload()
|
||||
return self.load(reset_cache=True)
|
||||
|
||||
def unload(self) -> None:
|
||||
for name, mod in list(sys.modules.items()):
|
||||
if getattr(mod, "__file__", "").startswith(self.path):
|
||||
del sys.modules[name]
|
||||
self._loaded = None
|
||||
self.log.debug(f"Unloaded plugin {self.id} at {self.path}")
|
||||
|
||||
def destroy(self) -> None:
|
||||
self.unload()
|
||||
try:
|
||||
del self.path_cache[self.path]
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
del self.id_cache[self.id]
|
||||
except KeyError:
|
||||
pass
|
||||
self.id = None
|
||||
self.path = None
|
||||
self.version = None
|
||||
self.modules = None
|
||||
if self._importer:
|
||||
self._importer.remove_cache()
|
||||
self._importer = None
|
||||
self._loaded = None
|
||||
|
||||
@classmethod
|
||||
def load_all(cls, *args: str) -> None:
|
||||
cls.log.debug("Preloading plugins...")
|
||||
for directory in args:
|
||||
for file in os.listdir(directory):
|
||||
if not file.endswith(".mbp"):
|
||||
continue
|
||||
path = os.path.join(directory, file)
|
||||
try:
|
||||
ZippedPluginLoader.get(path)
|
||||
except (MaubotZipImportError, IDConflictError):
|
||||
cls.log.exception(f"Failed to load plugin at {path}")
|
126
maubot/matrix.py
Normal file
126
maubot/matrix.py
Normal file
@ -0,0 +1,126 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, List, Union, Callable, Awaitable
|
||||
import attr
|
||||
import commonmark
|
||||
|
||||
from mautrix import Client as MatrixClient
|
||||
from mautrix.client import EventHandler
|
||||
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
|
||||
MessageType, TextMessageEventContent, Format)
|
||||
|
||||
from .command_spec import ParsedCommand, CommandSpec
|
||||
|
||||
|
||||
class MaubotMessageEvent(MessageEvent):
|
||||
_client: MatrixClient
|
||||
|
||||
def __init__(self, base: MessageEvent, client: MatrixClient):
|
||||
super().__init__(**{a.name.lstrip("_"): getattr(base, a.name)
|
||||
for a in attr.fields(MessageEvent)})
|
||||
self._client = client
|
||||
|
||||
def respond(self, content: Union[str, MessageEventContent],
|
||||
event_type: EventType = EventType.ROOM_MESSAGE,
|
||||
markdown: bool = True) -> Awaitable[EventID]:
|
||||
if isinstance(content, str):
|
||||
content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=content)
|
||||
if markdown:
|
||||
content.format = Format.HTML
|
||||
content.formatted_body = commonmark.commonmark(content.body)
|
||||
return self._client.send_message_event(self.room_id, event_type, content)
|
||||
|
||||
def reply(self, content: Union[str, MessageEventContent],
|
||||
event_type: EventType = EventType.ROOM_MESSAGE,
|
||||
markdown: bool = True) -> Awaitable[EventID]:
|
||||
if isinstance(content, str):
|
||||
content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=content)
|
||||
if markdown:
|
||||
content.format = Format.HTML
|
||||
content.formatted_body = commonmark.commonmark(content.body)
|
||||
content.set_reply(self)
|
||||
return self._client.send_message_event(self.room_id, event_type, content)
|
||||
|
||||
def mark_read(self) -> Awaitable[None]:
|
||||
return self._client.send_receipt(self.room_id, self.event_id, "m.read")
|
||||
|
||||
|
||||
class MaubotMatrixClient(MatrixClient):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.command_handlers: Dict[str, List[EventHandler]] = {}
|
||||
self.commands: List[ParsedCommand] = []
|
||||
self.command_specs: Dict[str, CommandSpec] = {}
|
||||
|
||||
self.add_event_handler(self._command_event_handler, EventType.ROOM_MESSAGE)
|
||||
|
||||
def set_command_spec(self, plugin_id: str, spec: CommandSpec) -> None:
|
||||
self.command_specs[plugin_id] = spec
|
||||
self._reparse_command_specs()
|
||||
|
||||
def _reparse_command_specs(self) -> None:
|
||||
self.commands = [parsed_command
|
||||
for spec in self.command_specs.values()
|
||||
for parsed_command in spec.parse()]
|
||||
|
||||
def remove_command_spec(self, plugin_id: str) -> None:
|
||||
try:
|
||||
del self.command_specs[plugin_id]
|
||||
self._reparse_command_specs()
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
async def _command_event_handler(self, evt: MessageEvent) -> None:
|
||||
if evt.sender == self.mxid or evt.content.msgtype != MessageType.TEXT:
|
||||
return
|
||||
for command in self.commands:
|
||||
if command.match(evt):
|
||||
await self._trigger_command(command, evt)
|
||||
return
|
||||
|
||||
async def _trigger_command(self, command: ParsedCommand, evt: MessageEvent) -> None:
|
||||
for handler in self.command_handlers.get(command.name, []):
|
||||
await handler(evt)
|
||||
|
||||
def on(self, var: Union[EventHandler, EventType, str]
|
||||
) -> Union[EventHandler, Callable[[EventHandler], EventHandler]]:
|
||||
if isinstance(var, str):
|
||||
def decorator(func: EventHandler) -> EventHandler:
|
||||
self.add_command_handler(var, func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
return super().on(var)
|
||||
|
||||
def add_command_handler(self, command: str, handler: EventHandler) -> None:
|
||||
self.command_handlers.setdefault(command, []).append(handler)
|
||||
|
||||
def remove_command_handler(self, command: str, handler: EventHandler) -> None:
|
||||
try:
|
||||
self.command_handlers[command].remove(handler)
|
||||
except (KeyError, ValueError):
|
||||
pass
|
||||
|
||||
async def call_handlers(self, event: Event) -> None:
|
||||
if isinstance(event, MessageEvent):
|
||||
event = MaubotMessageEvent(event, self)
|
||||
return await super().call_handlers(event)
|
||||
|
||||
async def get_event(self, room_id: RoomID, event_id: EventID) -> Event:
|
||||
event = await super().get_event(room_id, event_id)
|
||||
if isinstance(event, MessageEvent):
|
||||
return MaubotMessageEvent(event, self)
|
||||
return event
|
122
maubot/plugin.py
Normal file
122
maubot/plugin.py
Normal file
@ -0,0 +1,122 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import Dict, List, Optional
|
||||
import logging
|
||||
|
||||
from mautrix.types import UserID
|
||||
|
||||
from .db import DBPlugin
|
||||
from .client import Client
|
||||
from .loader import PluginLoader
|
||||
from .plugin_base import Plugin
|
||||
|
||||
log = logging.getLogger("maubot.plugin")
|
||||
|
||||
|
||||
class PluginInstance:
|
||||
cache: Dict[str, 'PluginInstance'] = {}
|
||||
plugin_directories: List[str] = []
|
||||
|
||||
log: logging.Logger
|
||||
loader: PluginLoader
|
||||
client: Client
|
||||
plugin: Plugin
|
||||
|
||||
def __init__(self, db_instance: DBPlugin):
|
||||
self.db_instance = db_instance
|
||||
self.log = logging.getLogger(f"maubot.plugin.{self.id}")
|
||||
self.cache[self.id] = self
|
||||
|
||||
def load(self) -> None:
|
||||
try:
|
||||
self.loader = PluginLoader.find(self.type)
|
||||
except KeyError:
|
||||
self.log.error(f"Failed to find loader for type {self.type}")
|
||||
self.enabled = False
|
||||
return
|
||||
self.client = Client.get(self.primary_user)
|
||||
if not self.client:
|
||||
self.log.error(f"Failed to get client for user {self.primary_user}")
|
||||
self.enabled = False
|
||||
self.log.debug("Plugin instance dependencies loaded")
|
||||
|
||||
async def start(self) -> None:
|
||||
if not self.enabled:
|
||||
self.log.warn(f"Plugin disabled, not starting.")
|
||||
return
|
||||
cls = self.loader.load()
|
||||
self.plugin = cls(self.client.client, self.id, self.log)
|
||||
self.loader.references |= {self}
|
||||
await self.plugin.start()
|
||||
self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} "
|
||||
f"with user {self.client.id}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
self.log.debug("Stopping plugin instance...")
|
||||
self.loader.references -= {self}
|
||||
await self.plugin.stop()
|
||||
self.plugin = None
|
||||
|
||||
@classmethod
|
||||
def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None
|
||||
) -> Optional['PluginInstance']:
|
||||
try:
|
||||
return cls.cache[instance_id]
|
||||
except KeyError:
|
||||
db_instance = db_instance or DBPlugin.query.get(instance_id)
|
||||
if not db_instance:
|
||||
return None
|
||||
return PluginInstance(db_instance)
|
||||
|
||||
@classmethod
|
||||
def all(cls) -> List['PluginInstance']:
|
||||
return [cls.get(plugin.id, plugin) for plugin in DBPlugin.query.all()]
|
||||
|
||||
# region Properties
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.db_instance.id
|
||||
|
||||
@id.setter
|
||||
def id(self, value: str) -> None:
|
||||
self.db_instance.id = value
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.db_instance.type
|
||||
|
||||
@type.setter
|
||||
def type(self, value: str) -> None:
|
||||
self.db_instance.type = value
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self.db_instance.enabled
|
||||
|
||||
@enabled.setter
|
||||
def enabled(self, value: bool) -> None:
|
||||
self.db_instance.enabled = value
|
||||
|
||||
@property
|
||||
def primary_user(self) -> UserID:
|
||||
return self.db_instance.primary_user
|
||||
|
||||
@primary_user.setter
|
||||
def primary_user(self, value: UserID) -> None:
|
||||
self.db_instance.primary_user = value
|
||||
|
||||
# endregion
|
42
maubot/plugin_base.py
Normal file
42
maubot/plugin_base.py
Normal file
@ -0,0 +1,42 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from typing import TYPE_CHECKING
|
||||
from logging import Logger
|
||||
from abc import ABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import MaubotMatrixClient
|
||||
from .command_spec import CommandSpec
|
||||
|
||||
|
||||
class Plugin(ABC):
|
||||
client: 'MaubotMatrixClient'
|
||||
id: str
|
||||
log: Logger
|
||||
|
||||
def __init__(self, client: 'MaubotMatrixClient', plugin_instance_id: str, log: Logger) -> None:
|
||||
self.client = client
|
||||
self.id = plugin_instance_id
|
||||
self.log = log
|
||||
|
||||
def set_command_spec(self, spec: 'CommandSpec') -> None:
|
||||
self.client.set_command_spec(self.id, spec)
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
pass
|
57
maubot/server.py
Normal file
57
maubot/server.py
Normal file
@ -0,0 +1,57 @@
|
||||
# maubot - A plugin-based Matrix bot system.
|
||||
# Copyright (C) 2018 Tulir Asokan
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from aiohttp import web
|
||||
import asyncio
|
||||
|
||||
from mautrix.api import PathBuilder, Method
|
||||
|
||||
from .config import Config
|
||||
from .__meta__ import __version__
|
||||
|
||||
|
||||
class MaubotServer:
|
||||
def __init__(self, config: Config, loop: asyncio.AbstractEventLoop):
|
||||
self.loop = loop or asyncio.get_event_loop()
|
||||
self.app = web.Application(loop=self.loop)
|
||||
self.config = config
|
||||
|
||||
path = PathBuilder(config["server.base_path"])
|
||||
self.add_route(Method.GET, path.version, self.version)
|
||||
|
||||
as_path = PathBuilder(config["server.appservice_base_path"])
|
||||
self.add_route(Method.PUT, as_path.transactions, self.handle_transaction)
|
||||
|
||||
self.runner = web.AppRunner(self.app)
|
||||
|
||||
def add_route(self, method: Method, path: PathBuilder, handler) -> None:
|
||||
self.app.router.add_route(method.value, str(path), handler)
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.runner.setup()
|
||||
site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"])
|
||||
await site.start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
await self.runner.cleanup()
|
||||
|
||||
@staticmethod
|
||||
async def version(_: web.Request) -> web.Response:
|
||||
return web.json_response({
|
||||
"version": __version__
|
||||
})
|
||||
|
||||
async def handle_transaction(self, request: web.Request) -> web.Response:
|
||||
return web.Response(status=501)
|
31
plugin.go
31
plugin.go
@ -1,31 +0,0 @@
|
||||
// maubot - A plugin-based Matrix bot system written in Go.
|
||||
// Copyright (C) 2018 Tulir Asokan
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package maubot
|
||||
|
||||
type Plugin interface {
|
||||
Start()
|
||||
Stop()
|
||||
}
|
||||
|
||||
type PluginCreatorFunc func(client MatrixClient, logger Logger) Plugin
|
||||
|
||||
type PluginCreator struct {
|
||||
Create PluginCreatorFunc
|
||||
Name string
|
||||
Version string
|
||||
Path string
|
||||
}
|
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@ -0,0 +1,7 @@
|
||||
mautrix
|
||||
aiohttp
|
||||
SQLAlchemy
|
||||
alembic
|
||||
commonmark
|
||||
ruamel.yaml
|
||||
attrs
|
50
setup.py
Normal file
50
setup.py
Normal file
@ -0,0 +1,50 @@
|
||||
import setuptools
|
||||
import os
|
||||
|
||||
path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "maubot", "__meta__.py")
|
||||
__version__ = "UNKNOWN"
|
||||
with open(path) as f:
|
||||
exec(f.read())
|
||||
|
||||
setuptools.setup(
|
||||
name="maubot",
|
||||
version=__version__,
|
||||
url="https://github.com/maubot/maubot",
|
||||
|
||||
author="Tulir Asokan",
|
||||
author_email="tulir@maunium.net",
|
||||
|
||||
description="A plugin-based Matrix bot system.",
|
||||
long_description=open("README.md").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
packages=setuptools.find_packages(),
|
||||
|
||||
install_requires=[
|
||||
"mautrix>=0.4,<0.5",
|
||||
"aiohttp>=3.0.1,<4",
|
||||
"SQLAlchemy>=1.2.3,<2",
|
||||
"alembic>=1.0.0,<2",
|
||||
"commonmark>=0.8.1,<1",
|
||||
"ruamel.yaml>=0.15.35,<0.16",
|
||||
"attrs>=18.2.0,<19",
|
||||
],
|
||||
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)",
|
||||
"Topic :: Communications :: Chat",
|
||||
"Framework :: AsyncIO",
|
||||
"Programming Language :: Python",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
],
|
||||
entry_points="""
|
||||
[console_scripts]
|
||||
maubot=maubot.__main__:main
|
||||
""",
|
||||
data_files=[
|
||||
(".", ["example-config.yaml"]),
|
||||
],
|
||||
)
|
19
vendor/github.com/gorilla/context/.travis.yml
generated
vendored
19
vendor/github.com/gorilla/context/.travis.yml
generated
vendored
@ -1,19 +0,0 @@
|
||||
language: go
|
||||
sudo: false
|
||||
|
||||
matrix:
|
||||
include:
|
||||
- go: 1.3
|
||||
- go: 1.4
|
||||
- go: 1.5
|
||||
- go: 1.6
|
||||
- go: 1.7
|
||||
- go: tip
|
||||
allow_failures:
|
||||
- go: tip
|
||||
|
||||
script:
|
||||
- go get -t -v ./...
|
||||
- diff -u <(echo -n) <(gofmt -d .)
|
||||
- go vet $(go list ./... | grep -v /vendor/)
|
||||
- go test -v -race ./...
|
27
vendor/github.com/gorilla/context/LICENSE
generated
vendored
27
vendor/github.com/gorilla/context/LICENSE
generated
vendored
@ -1,27 +0,0 @@
|
||||
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
10
vendor/github.com/gorilla/context/README.md
generated
vendored
10
vendor/github.com/gorilla/context/README.md
generated
vendored
@ -1,10 +0,0 @@
|
||||
context
|
||||
=======
|
||||
[![Build Status](https://travis-ci.org/gorilla/context.png?branch=master)](https://travis-ci.org/gorilla/context)
|
||||
|
||||
gorilla/context is a general purpose registry for global request variables.
|
||||
|
||||
> Note: gorilla/context, having been born well before `context.Context` existed, does not play well
|
||||
> with the shallow copying of the request that [`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext) (added to net/http Go 1.7 onwards) performs. You should either use *just* gorilla/context, or moving forward, the new `http.Request.Context()`.
|
||||
|
||||
Read the full documentation here: http://www.gorillatoolkit.org/pkg/context
|
143
vendor/github.com/gorilla/context/context.go
generated
vendored
143
vendor/github.com/gorilla/context/context.go
generated
vendored
@ -1,143 +0,0 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package context
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
mutex sync.RWMutex
|
||||
data = make(map[*http.Request]map[interface{}]interface{})
|
||||
datat = make(map[*http.Request]int64)
|
||||
)
|
||||
|
||||
// Set stores a value for a given key in a given request.
|
||||
func Set(r *http.Request, key, val interface{}) {
|
||||
mutex.Lock()
|
||||
if data[r] == nil {
|
||||
data[r] = make(map[interface{}]interface{})
|
||||
datat[r] = time.Now().Unix()
|
||||
}
|
||||
data[r][key] = val
|
||||
mutex.Unlock()
|
||||
}
|
||||
|
||||
// Get returns a value stored for a given key in a given request.
|
||||
func Get(r *http.Request, key interface{}) interface{} {
|
||||
mutex.RLock()
|
||||
if ctx := data[r]; ctx != nil {
|
||||
value := ctx[key]
|
||||
mutex.RUnlock()
|
||||
return value
|
||||
}
|
||||
mutex.RUnlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetOk returns stored value and presence state like multi-value return of map access.
|
||||
func GetOk(r *http.Request, key interface{}) (interface{}, bool) {
|
||||
mutex.RLock()
|
||||
if _, ok := data[r]; ok {
|
||||
value, ok := data[r][key]
|
||||
mutex.RUnlock()
|
||||
return value, ok
|
||||
}
|
||||
mutex.RUnlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// GetAll returns all stored values for the request as a map. Nil is returned for invalid requests.
|
||||
func GetAll(r *http.Request) map[interface{}]interface{} {
|
||||
mutex.RLock()
|
||||
if context, ok := data[r]; ok {
|
||||
result := make(map[interface{}]interface{}, len(context))
|
||||
for k, v := range context {
|
||||
result[k] = v
|
||||
}
|
||||
mutex.RUnlock()
|
||||
return result
|
||||
}
|
||||
mutex.RUnlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllOk returns all stored values for the request as a map and a boolean value that indicates if
|
||||
// the request was registered.
|
||||
func GetAllOk(r *http.Request) (map[interface{}]interface{}, bool) {
|
||||
mutex.RLock()
|
||||
context, ok := data[r]
|
||||
result := make(map[interface{}]interface{}, len(context))
|
||||
for k, v := range context {
|
||||
result[k] = v
|
||||
}
|
||||
mutex.RUnlock()
|
||||
return result, ok
|
||||
}
|
||||
|
||||
// Delete removes a value stored for a given key in a given request.
|
||||
func Delete(r *http.Request, key interface{}) {
|
||||
mutex.Lock()
|
||||
if data[r] != nil {
|
||||
delete(data[r], key)
|
||||
}
|
||||
mutex.Unlock()
|
||||
}
|
||||
|
||||
// Clear removes all values stored for a given request.
|
||||
//
|
||||
// This is usually called by a handler wrapper to clean up request
|
||||
// variables at the end of a request lifetime. See ClearHandler().
|
||||
func Clear(r *http.Request) {
|
||||
mutex.Lock()
|
||||
clear(r)
|
||||
mutex.Unlock()
|
||||
}
|
||||
|
||||
// clear is Clear without the lock.
|
||||
func clear(r *http.Request) {
|
||||
delete(data, r)
|
||||
delete(datat, r)
|
||||
}
|
||||
|
||||
// Purge removes request data stored for longer than maxAge, in seconds.
|
||||
// It returns the amount of requests removed.
|
||||
//
|
||||
// If maxAge <= 0, all request data is removed.
|
||||
//
|
||||
// This is only used for sanity check: in case context cleaning was not
|
||||
// properly set some request data can be kept forever, consuming an increasing
|
||||
// amount of memory. In case this is detected, Purge() must be called
|
||||
// periodically until the problem is fixed.
|
||||
func Purge(maxAge int) int {
|
||||
mutex.Lock()
|
||||
count := 0
|
||||
if maxAge <= 0 {
|
||||
count = len(data)
|
||||
data = make(map[*http.Request]map[interface{}]interface{})
|
||||
datat = make(map[*http.Request]int64)
|
||||
} else {
|
||||
min := time.Now().Unix() - int64(maxAge)
|
||||
for r := range data {
|
||||
if datat[r] < min {
|
||||
clear(r)
|
||||
count++
|
||||
}
|
||||
}
|
||||
}
|
||||
mutex.Unlock()
|
||||
return count
|
||||
}
|
||||
|
||||
// ClearHandler wraps an http.Handler and clears request values at the end
|
||||
// of a request lifetime.
|
||||
func ClearHandler(h http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer Clear(r)
|
||||
h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
88
vendor/github.com/gorilla/context/doc.go
generated
vendored
88
vendor/github.com/gorilla/context/doc.go
generated
vendored
@ -1,88 +0,0 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package context stores values shared during a request lifetime.
|
||||
|
||||
Note: gorilla/context, having been born well before `context.Context` existed,
|
||||
does not play well > with the shallow copying of the request that
|
||||
[`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext)
|
||||
(added to net/http Go 1.7 onwards) performs. You should either use *just*
|
||||
gorilla/context, or moving forward, the new `http.Request.Context()`.
|
||||
|
||||
For example, a router can set variables extracted from the URL and later
|
||||
application handlers can access those values, or it can be used to store
|
||||
sessions values to be saved at the end of a request. There are several
|
||||
others common uses.
|
||||
|
||||
The idea was posted by Brad Fitzpatrick to the go-nuts mailing list:
|
||||
|
||||
http://groups.google.com/group/golang-nuts/msg/e2d679d303aa5d53
|
||||
|
||||
Here's the basic usage: first define the keys that you will need. The key
|
||||
type is interface{} so a key can be of any type that supports equality.
|
||||
Here we define a key using a custom int type to avoid name collisions:
|
||||
|
||||
package foo
|
||||
|
||||
import (
|
||||
"github.com/gorilla/context"
|
||||
)
|
||||
|
||||
type key int
|
||||
|
||||
const MyKey key = 0
|
||||
|
||||
Then set a variable. Variables are bound to an http.Request object, so you
|
||||
need a request instance to set a value:
|
||||
|
||||
context.Set(r, MyKey, "bar")
|
||||
|
||||
The application can later access the variable using the same key you provided:
|
||||
|
||||
func MyHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// val is "bar".
|
||||
val := context.Get(r, foo.MyKey)
|
||||
|
||||
// returns ("bar", true)
|
||||
val, ok := context.GetOk(r, foo.MyKey)
|
||||
// ...
|
||||
}
|
||||
|
||||
And that's all about the basic usage. We discuss some other ideas below.
|
||||
|
||||
Any type can be stored in the context. To enforce a given type, make the key
|
||||
private and wrap Get() and Set() to accept and return values of a specific
|
||||
type:
|
||||
|
||||
type key int
|
||||
|
||||
const mykey key = 0
|
||||
|
||||
// GetMyKey returns a value for this package from the request values.
|
||||
func GetMyKey(r *http.Request) SomeType {
|
||||
if rv := context.Get(r, mykey); rv != nil {
|
||||
return rv.(SomeType)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMyKey sets a value for this package in the request values.
|
||||
func SetMyKey(r *http.Request, val SomeType) {
|
||||
context.Set(r, mykey, val)
|
||||
}
|
||||
|
||||
Variables must be cleared at the end of a request, to remove all values
|
||||
that were stored. This can be done in an http.Handler, after a request was
|
||||
served. Just call Clear() passing the request:
|
||||
|
||||
context.Clear(r)
|
||||
|
||||
...or use ClearHandler(), which conveniently wraps an http.Handler to clear
|
||||
variables at the end of a request lifetime.
|
||||
|
||||
The Routers from the packages gorilla/mux and gorilla/pat call Clear()
|
||||
so if you are using either of them you don't need to clear the context manually.
|
||||
*/
|
||||
package context
|
23
vendor/github.com/gorilla/mux/.travis.yml
generated
vendored
23
vendor/github.com/gorilla/mux/.travis.yml
generated
vendored
@ -1,23 +0,0 @@
|
||||
language: go
|
||||
sudo: false
|
||||
|
||||
matrix:
|
||||
include:
|
||||
- go: 1.5.x
|
||||
- go: 1.6.x
|
||||
- go: 1.7.x
|
||||
- go: 1.8.x
|
||||
- go: 1.9.x
|
||||
- go: 1.10.x
|
||||
- go: tip
|
||||
allow_failures:
|
||||
- go: tip
|
||||
|
||||
install:
|
||||
- # Skip
|
||||
|
||||
script:
|
||||
- go get -t -v ./...
|
||||
- diff -u <(echo -n) <(gofmt -d .)
|
||||
- go tool vet .
|
||||
- go test -v -race ./...
|
11
vendor/github.com/gorilla/mux/ISSUE_TEMPLATE.md
generated
vendored
11
vendor/github.com/gorilla/mux/ISSUE_TEMPLATE.md
generated
vendored
@ -1,11 +0,0 @@
|
||||
**What version of Go are you running?** (Paste the output of `go version`)
|
||||
|
||||
|
||||
**What version of gorilla/mux are you at?** (Paste the output of `git rev-parse HEAD` inside `$GOPATH/src/github.com/gorilla/mux`)
|
||||
|
||||
|
||||
**Describe your problem** (and what you have tried so far)
|
||||
|
||||
|
||||
**Paste a minimal, runnable, reproduction of your issue below** (use backticks to format it)
|
||||
|
27
vendor/github.com/gorilla/mux/LICENSE
generated
vendored
27
vendor/github.com/gorilla/mux/LICENSE
generated
vendored
@ -1,27 +0,0 @@
|
||||
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
* Neither the name of Google Inc. nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
649
vendor/github.com/gorilla/mux/README.md
generated
vendored
649
vendor/github.com/gorilla/mux/README.md
generated
vendored
@ -1,649 +0,0 @@
|
||||
# gorilla/mux
|
||||
|
||||
[![GoDoc](https://godoc.org/github.com/gorilla/mux?status.svg)](https://godoc.org/github.com/gorilla/mux)
|
||||
[![Build Status](https://travis-ci.org/gorilla/mux.svg?branch=master)](https://travis-ci.org/gorilla/mux)
|
||||
[![Sourcegraph](https://sourcegraph.com/github.com/gorilla/mux/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/mux?badge)
|
||||
|
||||
![Gorilla Logo](http://www.gorillatoolkit.org/static/images/gorilla-icon-64.png)
|
||||
|
||||
http://www.gorillatoolkit.org/pkg/mux
|
||||
|
||||
Package `gorilla/mux` implements a request router and dispatcher for matching incoming requests to
|
||||
their respective handler.
|
||||
|
||||
The name mux stands for "HTTP request multiplexer". Like the standard `http.ServeMux`, `mux.Router` matches incoming requests against a list of registered routes and calls a handler for the route that matches the URL or other conditions. The main features are:
|
||||
|
||||
* It implements the `http.Handler` interface so it is compatible with the standard `http.ServeMux`.
|
||||
* Requests can be matched based on URL host, path, path prefix, schemes, header and query values, HTTP methods or using custom matchers.
|
||||
* URL hosts, paths and query values can have variables with an optional regular expression.
|
||||
* Registered URLs can be built, or "reversed", which helps maintaining references to resources.
|
||||
* Routes can be used as subrouters: nested routes are only tested if the parent route matches. This is useful to define groups of routes that share common conditions like a host, a path prefix or other repeated attributes. As a bonus, this optimizes request matching.
|
||||
|
||||
---
|
||||
|
||||
* [Install](#install)
|
||||
* [Examples](#examples)
|
||||
* [Matching Routes](#matching-routes)
|
||||
* [Static Files](#static-files)
|
||||
* [Registered URLs](#registered-urls)
|
||||
* [Walking Routes](#walking-routes)
|
||||
* [Graceful Shutdown](#graceful-shutdown)
|
||||
* [Middleware](#middleware)
|
||||
* [Testing Handlers](#testing-handlers)
|
||||
* [Full Example](#full-example)
|
||||
|
||||
---
|
||||
|
||||
## Install
|
||||
|
||||
With a [correctly configured](https://golang.org/doc/install#testing) Go toolchain:
|
||||
|
||||
```sh
|
||||
go get -u github.com/gorilla/mux
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
Let's start registering a couple of URL paths and handlers:
|
||||
|
||||
```go
|
||||
func main() {
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/", HomeHandler)
|
||||
r.HandleFunc("/products", ProductsHandler)
|
||||
r.HandleFunc("/articles", ArticlesHandler)
|
||||
http.Handle("/", r)
|
||||
}
|
||||
```
|
||||
|
||||
Here we register three routes mapping URL paths to handlers. This is equivalent to how `http.HandleFunc()` works: if an incoming request URL matches one of the paths, the corresponding handler is called passing (`http.ResponseWriter`, `*http.Request`) as parameters.
|
||||
|
||||
Paths can have variables. They are defined using the format `{name}` or `{name:pattern}`. If a regular expression pattern is not defined, the matched variable will be anything until the next slash. For example:
|
||||
|
||||
```go
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/products/{key}", ProductHandler)
|
||||
r.HandleFunc("/articles/{category}/", ArticlesCategoryHandler)
|
||||
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler)
|
||||
```
|
||||
|
||||
The names are used to create a map of route variables which can be retrieved calling `mux.Vars()`:
|
||||
|
||||
```go
|
||||
func ArticlesCategoryHandler(w http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, "Category: %v\n", vars["category"])
|
||||
}
|
||||
```
|
||||
|
||||
And this is all you need to know about the basic usage. More advanced options are explained below.
|
||||
|
||||
### Matching Routes
|
||||
|
||||
Routes can also be restricted to a domain or subdomain. Just define a host pattern to be matched. They can also have variables:
|
||||
|
||||
```go
|
||||
r := mux.NewRouter()
|
||||
// Only matches if domain is "www.example.com".
|
||||
r.Host("www.example.com")
|
||||
// Matches a dynamic subdomain.
|
||||
r.Host("{subdomain:[a-z]+}.domain.com")
|
||||
```
|
||||
|
||||
There are several other matchers that can be added. To match path prefixes:
|
||||
|
||||
```go
|
||||
r.PathPrefix("/products/")
|
||||
```
|
||||
|
||||
...or HTTP methods:
|
||||
|
||||
```go
|
||||
r.Methods("GET", "POST")
|
||||
```
|
||||
|
||||
...or URL schemes:
|
||||
|
||||
```go
|
||||
r.Schemes("https")
|
||||
```
|
||||
|
||||
...or header values:
|
||||
|
||||
```go
|
||||
r.Headers("X-Requested-With", "XMLHttpRequest")
|
||||
```
|
||||
|
||||
...or query values:
|
||||
|
||||
```go
|
||||
r.Queries("key", "value")
|
||||
```
|
||||
|
||||
...or to use a custom matcher function:
|
||||
|
||||
```go
|
||||
r.MatcherFunc(func(r *http.Request, rm *RouteMatch) bool {
|
||||
return r.ProtoMajor == 0
|
||||
})
|
||||
```
|
||||
|
||||
...and finally, it is possible to combine several matchers in a single route:
|
||||
|
||||
```go
|
||||
r.HandleFunc("/products", ProductsHandler).
|
||||
Host("www.example.com").
|
||||
Methods("GET").
|
||||
Schemes("http")
|
||||
```
|
||||
|
||||
Routes are tested in the order they were added to the router. If two routes match, the first one wins:
|
||||
|
||||
```go
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/specific", specificHandler)
|
||||
r.PathPrefix("/").Handler(catchAllHandler)
|
||||
```
|
||||
|
||||
Setting the same matching conditions again and again can be boring, so we have a way to group several routes that share the same requirements. We call it "subrouting".
|
||||
|
||||
For example, let's say we have several URLs that should only match when the host is `www.example.com`. Create a route for that host and get a "subrouter" from it:
|
||||
|
||||
```go
|
||||
r := mux.NewRouter()
|
||||
s := r.Host("www.example.com").Subrouter()
|
||||
```
|
||||
|
||||
Then register routes in the subrouter:
|
||||
|
||||
```go
|
||||
s.HandleFunc("/products/", ProductsHandler)
|
||||
s.HandleFunc("/products/{key}", ProductHandler)
|
||||
s.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler)
|
||||
```
|
||||
|
||||
The three URL paths we registered above will only be tested if the domain is `www.example.com`, because the subrouter is tested first. This is not only convenient, but also optimizes request matching. You can create subrouters combining any attribute matchers accepted by a route.
|
||||
|
||||
Subrouters can be used to create domain or path "namespaces": you define subrouters in a central place and then parts of the app can register its paths relatively to a given subrouter.
|
||||
|
||||
There's one more thing about subroutes. When a subrouter has a path prefix, the inner routes use it as base for their paths:
|
||||
|
||||
```go
|
||||
r := mux.NewRouter()
|
||||
s := r.PathPrefix("/products").Subrouter()
|
||||
// "/products/"
|
||||
s.HandleFunc("/", ProductsHandler)
|
||||
// "/products/{key}/"
|
||||
s.HandleFunc("/{key}/", ProductHandler)
|
||||
// "/products/{key}/details"
|
||||
s.HandleFunc("/{key}/details", ProductDetailsHandler)
|
||||
```
|
||||
|
||||
|
||||
### Static Files
|
||||
|
||||
Note that the path provided to `PathPrefix()` represents a "wildcard": calling
|
||||
`PathPrefix("/static/").Handler(...)` means that the handler will be passed any
|
||||
request that matches "/static/\*". This makes it easy to serve static files with mux:
|
||||
|
||||
```go
|
||||
func main() {
|
||||
var dir string
|
||||
|
||||
flag.StringVar(&dir, "dir", ".", "the directory to serve files from. Defaults to the current dir")
|
||||
flag.Parse()
|
||||
r := mux.NewRouter()
|
||||
|
||||
// This will serve files under http://localhost:8000/static/<filename>
|
||||
r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir(dir))))
|
||||
|
||||
srv := &http.Server{
|
||||
Handler: r,
|
||||
Addr: "127.0.0.1:8000",
|
||||
// Good practice: enforce timeouts for servers you create!
|
||||
WriteTimeout: 15 * time.Second,
|
||||
ReadTimeout: 15 * time.Second,
|
||||
}
|
||||
|
||||
log.Fatal(srv.ListenAndServe())
|
||||
}
|
||||
```
|
||||
|
||||
### Registered URLs
|
||||
|
||||
Now let's see how to build registered URLs.
|
||||
|
||||
Routes can be named. All routes that define a name can have their URLs built, or "reversed". We define a name calling `Name()` on a route. For example:
|
||||
|
||||
```go
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
|
||||
Name("article")
|
||||
```
|
||||
|
||||
To build a URL, get the route and call the `URL()` method, passing a sequence of key/value pairs for the route variables. For the previous route, we would do:
|
||||
|
||||
```go
|
||||
url, err := r.Get("article").URL("category", "technology", "id", "42")
|
||||
```
|
||||
|
||||
...and the result will be a `url.URL` with the following path:
|
||||
|
||||
```
|
||||
"/articles/technology/42"
|
||||
```
|
||||
|
||||
This also works for host and query value variables:
|
||||
|
||||
```go
|
||||
r := mux.NewRouter()
|
||||
r.Host("{subdomain}.domain.com").
|
||||
Path("/articles/{category}/{id:[0-9]+}").
|
||||
Queries("filter", "{filter}").
|
||||
HandlerFunc(ArticleHandler).
|
||||
Name("article")
|
||||
|
||||
// url.String() will be "http://news.domain.com/articles/technology/42?filter=gorilla"
|
||||
url, err := r.Get("article").URL("subdomain", "news",
|
||||
"category", "technology",
|
||||
"id", "42",
|
||||
"filter", "gorilla")
|
||||
```
|
||||
|
||||
All variables defined in the route are required, and their values must conform to the corresponding patterns. These requirements guarantee that a generated URL will always match a registered route -- the only exception is for explicitly defined "build-only" routes which never match.
|
||||
|
||||
Regex support also exists for matching Headers within a route. For example, we could do:
|
||||
|
||||
```go
|
||||
r.HeadersRegexp("Content-Type", "application/(text|json)")
|
||||
```
|
||||
|
||||
...and the route will match both requests with a Content-Type of `application/json` as well as `application/text`
|
||||
|
||||
There's also a way to build only the URL host or path for a route: use the methods `URLHost()` or `URLPath()` instead. For the previous route, we would do:
|
||||
|
||||
```go
|
||||
// "http://news.domain.com/"
|
||||
host, err := r.Get("article").URLHost("subdomain", "news")
|
||||
|
||||
// "/articles/technology/42"
|
||||
path, err := r.Get("article").URLPath("category", "technology", "id", "42")
|
||||
```
|
||||
|
||||
And if you use subrouters, host and path defined separately can be built as well:
|
||||
|
||||
```go
|
||||
r := mux.NewRouter()
|
||||
s := r.Host("{subdomain}.domain.com").Subrouter()
|
||||
s.Path("/articles/{category}/{id:[0-9]+}").
|
||||
HandlerFunc(ArticleHandler).
|
||||
Name("article")
|
||||
|
||||
// "http://news.domain.com/articles/technology/42"
|
||||
url, err := r.Get("article").URL("subdomain", "news",
|
||||
"category", "technology",
|
||||
"id", "42")
|
||||
```
|
||||
|
||||
### Walking Routes
|
||||
|
||||
The `Walk` function on `mux.Router` can be used to visit all of the routes that are registered on a router. For example,
|
||||
the following prints all of the registered routes:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
func main() {
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/", handler)
|
||||
r.HandleFunc("/products", handler).Methods("POST")
|
||||
r.HandleFunc("/articles", handler).Methods("GET")
|
||||
r.HandleFunc("/articles/{id}", handler).Methods("GET", "PUT")
|
||||
r.HandleFunc("/authors", handler).Queries("surname", "{surname}")
|
||||
err := r.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
|
||||
pathTemplate, err := route.GetPathTemplate()
|
||||
if err == nil {
|
||||
fmt.Println("ROUTE:", pathTemplate)
|
||||
}
|
||||
pathRegexp, err := route.GetPathRegexp()
|
||||
if err == nil {
|
||||
fmt.Println("Path regexp:", pathRegexp)
|
||||
}
|
||||
queriesTemplates, err := route.GetQueriesTemplates()
|
||||
if err == nil {
|
||||
fmt.Println("Queries templates:", strings.Join(queriesTemplates, ","))
|
||||
}
|
||||
queriesRegexps, err := route.GetQueriesRegexp()
|
||||
if err == nil {
|
||||
fmt.Println("Queries regexps:", strings.Join(queriesRegexps, ","))
|
||||
}
|
||||
methods, err := route.GetMethods()
|
||||
if err == nil {
|
||||
fmt.Println("Methods:", strings.Join(methods, ","))
|
||||
}
|
||||
fmt.Println()
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
http.Handle("/", r)
|
||||
}
|
||||
```
|
||||
|
||||
### Graceful Shutdown
|
||||
|
||||
Go 1.8 introduced the ability to [gracefully shutdown](https://golang.org/doc/go1.8#http_shutdown) a `*http.Server`. Here's how to do that alongside `mux`:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var wait time.Duration
|
||||
flag.DurationVar(&wait, "graceful-timeout", time.Second * 15, "the duration for which the server gracefully wait for existing connections to finish - e.g. 15s or 1m")
|
||||
flag.Parse()
|
||||
|
||||
r := mux.NewRouter()
|
||||
// Add your routes as needed
|
||||
|
||||
srv := &http.Server{
|
||||
Addr: "0.0.0.0:8080",
|
||||
// Good practice to set timeouts to avoid Slowloris attacks.
|
||||
WriteTimeout: time.Second * 15,
|
||||
ReadTimeout: time.Second * 15,
|
||||
IdleTimeout: time.Second * 60,
|
||||
Handler: r, // Pass our instance of gorilla/mux in.
|
||||
}
|
||||
|
||||
// Run our server in a goroutine so that it doesn't block.
|
||||
go func() {
|
||||
if err := srv.ListenAndServe(); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
}()
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
// We'll accept graceful shutdowns when quit via SIGINT (Ctrl+C)
|
||||
// SIGKILL, SIGQUIT or SIGTERM (Ctrl+/) will not be caught.
|
||||
signal.Notify(c, os.Interrupt)
|
||||
|
||||
// Block until we receive our signal.
|
||||
<-c
|
||||
|
||||
// Create a deadline to wait for.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), wait)
|
||||
defer cancel()
|
||||
// Doesn't block if no connections, but will otherwise wait
|
||||
// until the timeout deadline.
|
||||
srv.Shutdown(ctx)
|
||||
// Optionally, you could run srv.Shutdown in a goroutine and block on
|
||||
// <-ctx.Done() if your application should wait for other services
|
||||
// to finalize based on context cancellation.
|
||||
log.Println("shutting down")
|
||||
os.Exit(0)
|
||||
}
|
||||
```
|
||||
|
||||
### Middleware
|
||||
|
||||
Mux supports the addition of middlewares to a [Router](https://godoc.org/github.com/gorilla/mux#Router), which are executed in the order they are added if a match is found, including its subrouters.
|
||||
Middlewares are (typically) small pieces of code which take one request, do something with it, and pass it down to another middleware or the final handler. Some common use cases for middleware are request logging, header manipulation, or `ResponseWriter` hijacking.
|
||||
|
||||
Mux middlewares are defined using the de facto standard type:
|
||||
|
||||
```go
|
||||
type MiddlewareFunc func(http.Handler) http.Handler
|
||||
```
|
||||
|
||||
Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed to it, and then calls the handler passed as parameter to the MiddlewareFunc. This takes advantage of closures being able access variables from the context where they are created, while retaining the signature enforced by the receivers.
|
||||
|
||||
A very basic middleware which logs the URI of the request being handled could be written as:
|
||||
|
||||
```go
|
||||
func loggingMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Do stuff here
|
||||
log.Println(r.RequestURI)
|
||||
// Call the next handler, which can be another middleware in the chain, or the final handler.
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
Middlewares can be added to a router using `Router.Use()`:
|
||||
|
||||
```go
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/", handler)
|
||||
r.Use(loggingMiddleware)
|
||||
```
|
||||
|
||||
A more complex authentication middleware, which maps session token to users, could be written as:
|
||||
|
||||
```go
|
||||
// Define our struct
|
||||
type authenticationMiddleware struct {
|
||||
tokenUsers map[string]string
|
||||
}
|
||||
|
||||
// Initialize it somewhere
|
||||
func (amw *authenticationMiddleware) Populate() {
|
||||
amw.tokenUsers["00000000"] = "user0"
|
||||
amw.tokenUsers["aaaaaaaa"] = "userA"
|
||||
amw.tokenUsers["05f717e5"] = "randomUser"
|
||||
amw.tokenUsers["deadbeef"] = "user0"
|
||||
}
|
||||
|
||||
// Middleware function, which will be called for each request
|
||||
func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Header.Get("X-Session-Token")
|
||||
|
||||
if user, found := amw.tokenUsers[token]; found {
|
||||
// We found the token in our map
|
||||
log.Printf("Authenticated user %s\n", user)
|
||||
// Pass down the request to the next middleware (or final handler)
|
||||
next.ServeHTTP(w, r)
|
||||
} else {
|
||||
// Write an error and stop the handler chain
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
```go
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/", handler)
|
||||
|
||||
amw := authenticationMiddleware{}
|
||||
amw.Populate()
|
||||
|
||||
r.Use(amw.Middleware)
|
||||
```
|
||||
|
||||
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares _should_ write to `ResponseWriter` if they _are_ going to terminate the request, and they _should not_ write to `ResponseWriter` if they _are not_ going to terminate it.
|
||||
|
||||
### Testing Handlers
|
||||
|
||||
Testing handlers in a Go web application is straightforward, and _mux_ doesn't complicate this any further. Given two files: `endpoints.go` and `endpoints_test.go`, here's how we'd test an application using _mux_.
|
||||
|
||||
First, our simple HTTP handler:
|
||||
|
||||
```go
|
||||
// endpoints.go
|
||||
package main
|
||||
|
||||
func HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// A very simple health check.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
// In the future we could report back on the status of our DB, or our cache
|
||||
// (e.g. Redis) by performing a simple PING, and include them in the response.
|
||||
io.WriteString(w, `{"alive": true}`)
|
||||
}
|
||||
|
||||
func main() {
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/health", HealthCheckHandler)
|
||||
|
||||
log.Fatal(http.ListenAndServe("localhost:8080", r))
|
||||
}
|
||||
```
|
||||
|
||||
Our test code:
|
||||
|
||||
```go
|
||||
// endpoints_test.go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHealthCheckHandler(t *testing.T) {
|
||||
// Create a request to pass to our handler. We don't have any query parameters for now, so we'll
|
||||
// pass 'nil' as the third parameter.
|
||||
req, err := http.NewRequest("GET", "/health", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
|
||||
rr := httptest.NewRecorder()
|
||||
handler := http.HandlerFunc(HealthCheckHandler)
|
||||
|
||||
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
|
||||
// directly and pass in our Request and ResponseRecorder.
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// Check the status code is what we expect.
|
||||
if status := rr.Code; status != http.StatusOK {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v",
|
||||
status, http.StatusOK)
|
||||
}
|
||||
|
||||
// Check the response body is what we expect.
|
||||
expected := `{"alive": true}`
|
||||
if rr.Body.String() != expected {
|
||||
t.Errorf("handler returned unexpected body: got %v want %v",
|
||||
rr.Body.String(), expected)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
In the case that our routes have [variables](#examples), we can pass those in the request. We could write
|
||||
[table-driven tests](https://dave.cheney.net/2013/06/09/writing-table-driven-tests-in-go) to test multiple
|
||||
possible route variables as needed.
|
||||
|
||||
```go
|
||||
// endpoints.go
|
||||
func main() {
|
||||
r := mux.NewRouter()
|
||||
// A route with a route variable:
|
||||
r.HandleFunc("/metrics/{type}", MetricsHandler)
|
||||
|
||||
log.Fatal(http.ListenAndServe("localhost:8080", r))
|
||||
}
|
||||
```
|
||||
|
||||
Our test file, with a table-driven test of `routeVariables`:
|
||||
|
||||
```go
|
||||
// endpoints_test.go
|
||||
func TestMetricsHandler(t *testing.T) {
|
||||
tt := []struct{
|
||||
routeVariable string
|
||||
shouldPass bool
|
||||
}{
|
||||
{"goroutines", true},
|
||||
{"heap", true},
|
||||
{"counters", true},
|
||||
{"queries", true},
|
||||
{"adhadaeqm3k", false},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
path := fmt.Sprintf("/metrics/%s", tc.routeVariable)
|
||||
req, err := http.NewRequest("GET", path, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Need to create a router that we can pass the request through so that the vars will be added to the context
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/metrics/{type}", MetricsHandler)
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
// In this case, our MetricsHandler returns a non-200 response
|
||||
// for a route variable it doesn't know about.
|
||||
if rr.Code == http.StatusOK && !tc.shouldPass {
|
||||
t.Errorf("handler should have failed on routeVariable %s: got %v want %v",
|
||||
tc.routeVariable, rr.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Full Example
|
||||
|
||||
Here's a complete, runnable example of a small `mux` based server:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"log"
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
func YourHandler(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("Gorilla!\n"))
|
||||
}
|
||||
|
||||
func main() {
|
||||
r := mux.NewRouter()
|
||||
// Routes consist of a path and a handler function.
|
||||
r.HandleFunc("/", YourHandler)
|
||||
|
||||
// Bind to a port and pass our router in
|
||||
log.Fatal(http.ListenAndServe(":8000", r))
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
BSD licensed. See the LICENSE file for details.
|
26
vendor/github.com/gorilla/mux/context_gorilla.go
generated
vendored
26
vendor/github.com/gorilla/mux/context_gorilla.go
generated
vendored
@ -1,26 +0,0 @@
|
||||
// +build !go1.7
|
||||
|
||||
package mux
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/context"
|
||||
)
|
||||
|
||||
func contextGet(r *http.Request, key interface{}) interface{} {
|
||||
return context.Get(r, key)
|
||||
}
|
||||
|
||||
func contextSet(r *http.Request, key, val interface{}) *http.Request {
|
||||
if val == nil {
|
||||
return r
|
||||
}
|
||||
|
||||
context.Set(r, key, val)
|
||||
return r
|
||||
}
|
||||
|
||||
func contextClear(r *http.Request) {
|
||||
context.Clear(r)
|
||||
}
|
24
vendor/github.com/gorilla/mux/context_native.go
generated
vendored
24
vendor/github.com/gorilla/mux/context_native.go
generated
vendored
@ -1,24 +0,0 @@
|
||||
// +build go1.7
|
||||
|
||||
package mux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func contextGet(r *http.Request, key interface{}) interface{} {
|
||||
return r.Context().Value(key)
|
||||
}
|
||||
|
||||
func contextSet(r *http.Request, key, val interface{}) *http.Request {
|
||||
if val == nil {
|
||||
return r
|
||||
}
|
||||
|
||||
return r.WithContext(context.WithValue(r.Context(), key, val))
|
||||
}
|
||||
|
||||
func contextClear(r *http.Request) {
|
||||
return
|
||||
}
|
306
vendor/github.com/gorilla/mux/doc.go
generated
vendored
306
vendor/github.com/gorilla/mux/doc.go
generated
vendored
@ -1,306 +0,0 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package mux implements a request router and dispatcher.
|
||||
|
||||
The name mux stands for "HTTP request multiplexer". Like the standard
|
||||
http.ServeMux, mux.Router matches incoming requests against a list of
|
||||
registered routes and calls a handler for the route that matches the URL
|
||||
or other conditions. The main features are:
|
||||
|
||||
* Requests can be matched based on URL host, path, path prefix, schemes,
|
||||
header and query values, HTTP methods or using custom matchers.
|
||||
* URL hosts, paths and query values can have variables with an optional
|
||||
regular expression.
|
||||
* Registered URLs can be built, or "reversed", which helps maintaining
|
||||
references to resources.
|
||||
* Routes can be used as subrouters: nested routes are only tested if the
|
||||
parent route matches. This is useful to define groups of routes that
|
||||
share common conditions like a host, a path prefix or other repeated
|
||||
attributes. As a bonus, this optimizes request matching.
|
||||
* It implements the http.Handler interface so it is compatible with the
|
||||
standard http.ServeMux.
|
||||
|
||||
Let's start registering a couple of URL paths and handlers:
|
||||
|
||||
func main() {
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/", HomeHandler)
|
||||
r.HandleFunc("/products", ProductsHandler)
|
||||
r.HandleFunc("/articles", ArticlesHandler)
|
||||
http.Handle("/", r)
|
||||
}
|
||||
|
||||
Here we register three routes mapping URL paths to handlers. This is
|
||||
equivalent to how http.HandleFunc() works: if an incoming request URL matches
|
||||
one of the paths, the corresponding handler is called passing
|
||||
(http.ResponseWriter, *http.Request) as parameters.
|
||||
|
||||
Paths can have variables. They are defined using the format {name} or
|
||||
{name:pattern}. If a regular expression pattern is not defined, the matched
|
||||
variable will be anything until the next slash. For example:
|
||||
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/products/{key}", ProductHandler)
|
||||
r.HandleFunc("/articles/{category}/", ArticlesCategoryHandler)
|
||||
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler)
|
||||
|
||||
Groups can be used inside patterns, as long as they are non-capturing (?:re). For example:
|
||||
|
||||
r.HandleFunc("/articles/{category}/{sort:(?:asc|desc|new)}", ArticlesCategoryHandler)
|
||||
|
||||
The names are used to create a map of route variables which can be retrieved
|
||||
calling mux.Vars():
|
||||
|
||||
vars := mux.Vars(request)
|
||||
category := vars["category"]
|
||||
|
||||
Note that if any capturing groups are present, mux will panic() during parsing. To prevent
|
||||
this, convert any capturing groups to non-capturing, e.g. change "/{sort:(asc|desc)}" to
|
||||
"/{sort:(?:asc|desc)}". This is a change from prior versions which behaved unpredictably
|
||||
when capturing groups were present.
|
||||
|
||||
And this is all you need to know about the basic usage. More advanced options
|
||||
are explained below.
|
||||
|
||||
Routes can also be restricted to a domain or subdomain. Just define a host
|
||||
pattern to be matched. They can also have variables:
|
||||
|
||||
r := mux.NewRouter()
|
||||
// Only matches if domain is "www.example.com".
|
||||
r.Host("www.example.com")
|
||||
// Matches a dynamic subdomain.
|
||||
r.Host("{subdomain:[a-z]+}.domain.com")
|
||||
|
||||
There are several other matchers that can be added. To match path prefixes:
|
||||
|
||||
r.PathPrefix("/products/")
|
||||
|
||||
...or HTTP methods:
|
||||
|
||||
r.Methods("GET", "POST")
|
||||
|
||||
...or URL schemes:
|
||||
|
||||
r.Schemes("https")
|
||||
|
||||
...or header values:
|
||||
|
||||
r.Headers("X-Requested-With", "XMLHttpRequest")
|
||||
|
||||
...or query values:
|
||||
|
||||
r.Queries("key", "value")
|
||||
|
||||
...or to use a custom matcher function:
|
||||
|
||||
r.MatcherFunc(func(r *http.Request, rm *RouteMatch) bool {
|
||||
return r.ProtoMajor == 0
|
||||
})
|
||||
|
||||
...and finally, it is possible to combine several matchers in a single route:
|
||||
|
||||
r.HandleFunc("/products", ProductsHandler).
|
||||
Host("www.example.com").
|
||||
Methods("GET").
|
||||
Schemes("http")
|
||||
|
||||
Setting the same matching conditions again and again can be boring, so we have
|
||||
a way to group several routes that share the same requirements.
|
||||
We call it "subrouting".
|
||||
|
||||
For example, let's say we have several URLs that should only match when the
|
||||
host is "www.example.com". Create a route for that host and get a "subrouter"
|
||||
from it:
|
||||
|
||||
r := mux.NewRouter()
|
||||
s := r.Host("www.example.com").Subrouter()
|
||||
|
||||
Then register routes in the subrouter:
|
||||
|
||||
s.HandleFunc("/products/", ProductsHandler)
|
||||
s.HandleFunc("/products/{key}", ProductHandler)
|
||||
s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler)
|
||||
|
||||
The three URL paths we registered above will only be tested if the domain is
|
||||
"www.example.com", because the subrouter is tested first. This is not
|
||||
only convenient, but also optimizes request matching. You can create
|
||||
subrouters combining any attribute matchers accepted by a route.
|
||||
|
||||
Subrouters can be used to create domain or path "namespaces": you define
|
||||
subrouters in a central place and then parts of the app can register its
|
||||
paths relatively to a given subrouter.
|
||||
|
||||
There's one more thing about subroutes. When a subrouter has a path prefix,
|
||||
the inner routes use it as base for their paths:
|
||||
|
||||
r := mux.NewRouter()
|
||||
s := r.PathPrefix("/products").Subrouter()
|
||||
// "/products/"
|
||||
s.HandleFunc("/", ProductsHandler)
|
||||
// "/products/{key}/"
|
||||
s.HandleFunc("/{key}/", ProductHandler)
|
||||
// "/products/{key}/details"
|
||||
s.HandleFunc("/{key}/details", ProductDetailsHandler)
|
||||
|
||||
Note that the path provided to PathPrefix() represents a "wildcard": calling
|
||||
PathPrefix("/static/").Handler(...) means that the handler will be passed any
|
||||
request that matches "/static/*". This makes it easy to serve static files with mux:
|
||||
|
||||
func main() {
|
||||
var dir string
|
||||
|
||||
flag.StringVar(&dir, "dir", ".", "the directory to serve files from. Defaults to the current dir")
|
||||
flag.Parse()
|
||||
r := mux.NewRouter()
|
||||
|
||||
// This will serve files under http://localhost:8000/static/<filename>
|
||||
r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir(dir))))
|
||||
|
||||
srv := &http.Server{
|
||||
Handler: r,
|
||||
Addr: "127.0.0.1:8000",
|
||||
// Good practice: enforce timeouts for servers you create!
|
||||
WriteTimeout: 15 * time.Second,
|
||||
ReadTimeout: 15 * time.Second,
|
||||
}
|
||||
|
||||
log.Fatal(srv.ListenAndServe())
|
||||
}
|
||||
|
||||
Now let's see how to build registered URLs.
|
||||
|
||||
Routes can be named. All routes that define a name can have their URLs built,
|
||||
or "reversed". We define a name calling Name() on a route. For example:
|
||||
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
|
||||
Name("article")
|
||||
|
||||
To build a URL, get the route and call the URL() method, passing a sequence of
|
||||
key/value pairs for the route variables. For the previous route, we would do:
|
||||
|
||||
url, err := r.Get("article").URL("category", "technology", "id", "42")
|
||||
|
||||
...and the result will be a url.URL with the following path:
|
||||
|
||||
"/articles/technology/42"
|
||||
|
||||
This also works for host and query value variables:
|
||||
|
||||
r := mux.NewRouter()
|
||||
r.Host("{subdomain}.domain.com").
|
||||
Path("/articles/{category}/{id:[0-9]+}").
|
||||
Queries("filter", "{filter}").
|
||||
HandlerFunc(ArticleHandler).
|
||||
Name("article")
|
||||
|
||||
// url.String() will be "http://news.domain.com/articles/technology/42?filter=gorilla"
|
||||
url, err := r.Get("article").URL("subdomain", "news",
|
||||
"category", "technology",
|
||||
"id", "42",
|
||||
"filter", "gorilla")
|
||||
|
||||
All variables defined in the route are required, and their values must
|
||||
conform to the corresponding patterns. These requirements guarantee that a
|
||||
generated URL will always match a registered route -- the only exception is
|
||||
for explicitly defined "build-only" routes which never match.
|
||||
|
||||
Regex support also exists for matching Headers within a route. For example, we could do:
|
||||
|
||||
r.HeadersRegexp("Content-Type", "application/(text|json)")
|
||||
|
||||
...and the route will match both requests with a Content-Type of `application/json` as well as
|
||||
`application/text`
|
||||
|
||||
There's also a way to build only the URL host or path for a route:
|
||||
use the methods URLHost() or URLPath() instead. For the previous route,
|
||||
we would do:
|
||||
|
||||
// "http://news.domain.com/"
|
||||
host, err := r.Get("article").URLHost("subdomain", "news")
|
||||
|
||||
// "/articles/technology/42"
|
||||
path, err := r.Get("article").URLPath("category", "technology", "id", "42")
|
||||
|
||||
And if you use subrouters, host and path defined separately can be built
|
||||
as well:
|
||||
|
||||
r := mux.NewRouter()
|
||||
s := r.Host("{subdomain}.domain.com").Subrouter()
|
||||
s.Path("/articles/{category}/{id:[0-9]+}").
|
||||
HandlerFunc(ArticleHandler).
|
||||
Name("article")
|
||||
|
||||
// "http://news.domain.com/articles/technology/42"
|
||||
url, err := r.Get("article").URL("subdomain", "news",
|
||||
"category", "technology",
|
||||
"id", "42")
|
||||
|
||||
Mux supports the addition of middlewares to a Router, which are executed in the order they are added if a match is found, including its subrouters. Middlewares are (typically) small pieces of code which take one request, do something with it, and pass it down to another middleware or the final handler. Some common use cases for middleware are request logging, header manipulation, or ResponseWriter hijacking.
|
||||
|
||||
type MiddlewareFunc func(http.Handler) http.Handler
|
||||
|
||||
Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed to it, and then calls the handler passed as parameter to the MiddlewareFunc (closures can access variables from the context where they are created).
|
||||
|
||||
A very basic middleware which logs the URI of the request being handled could be written as:
|
||||
|
||||
func simpleMw(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Do stuff here
|
||||
log.Println(r.RequestURI)
|
||||
// Call the next handler, which can be another middleware in the chain, or the final handler.
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
Middlewares can be added to a router using `Router.Use()`:
|
||||
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/", handler)
|
||||
r.Use(simpleMw)
|
||||
|
||||
A more complex authentication middleware, which maps session token to users, could be written as:
|
||||
|
||||
// Define our struct
|
||||
type authenticationMiddleware struct {
|
||||
tokenUsers map[string]string
|
||||
}
|
||||
|
||||
// Initialize it somewhere
|
||||
func (amw *authenticationMiddleware) Populate() {
|
||||
amw.tokenUsers["00000000"] = "user0"
|
||||
amw.tokenUsers["aaaaaaaa"] = "userA"
|
||||
amw.tokenUsers["05f717e5"] = "randomUser"
|
||||
amw.tokenUsers["deadbeef"] = "user0"
|
||||
}
|
||||
|
||||
// Middleware function, which will be called for each request
|
||||
func (amw *authenticationMiddleware) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := r.Header.Get("X-Session-Token")
|
||||
|
||||
if user, found := amw.tokenUsers[token]; found {
|
||||
// We found the token in our map
|
||||
log.Printf("Authenticated user %s\n", user)
|
||||
next.ServeHTTP(w, r)
|
||||
} else {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
r := mux.NewRouter()
|
||||
r.HandleFunc("/", handler)
|
||||
|
||||
amw := authenticationMiddleware{}
|
||||
amw.Populate()
|
||||
|
||||
r.Use(amw.Middleware)
|
||||
|
||||
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to.
|
||||
|
||||
*/
|
||||
package mux
|
72
vendor/github.com/gorilla/mux/middleware.go
generated
vendored
72
vendor/github.com/gorilla/mux/middleware.go
generated
vendored
@ -1,72 +0,0 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// MiddlewareFunc is a function which receives an http.Handler and returns another http.Handler.
|
||||
// Typically, the returned handler is a closure which does something with the http.ResponseWriter and http.Request passed
|
||||
// to it, and then calls the handler passed as parameter to the MiddlewareFunc.
|
||||
type MiddlewareFunc func(http.Handler) http.Handler
|
||||
|
||||
// middleware interface is anything which implements a MiddlewareFunc named Middleware.
|
||||
type middleware interface {
|
||||
Middleware(handler http.Handler) http.Handler
|
||||
}
|
||||
|
||||
// Middleware allows MiddlewareFunc to implement the middleware interface.
|
||||
func (mw MiddlewareFunc) Middleware(handler http.Handler) http.Handler {
|
||||
return mw(handler)
|
||||
}
|
||||
|
||||
// Use appends a MiddlewareFunc to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
|
||||
func (r *Router) Use(mwf ...MiddlewareFunc) {
|
||||
for _, fn := range mwf {
|
||||
r.middlewares = append(r.middlewares, fn)
|
||||
}
|
||||
}
|
||||
|
||||
// useInterface appends a middleware to the chain. Middleware can be used to intercept or otherwise modify requests and/or responses, and are executed in the order that they are applied to the Router.
|
||||
func (r *Router) useInterface(mw middleware) {
|
||||
r.middlewares = append(r.middlewares, mw)
|
||||
}
|
||||
|
||||
// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header
|
||||
// on a request, by matching routes based only on paths. It also handles
|
||||
// OPTIONS requests, by settings Access-Control-Allow-Methods, and then
|
||||
// returning without calling the next http handler.
|
||||
func CORSMethodMiddleware(r *Router) MiddlewareFunc {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
var allMethods []string
|
||||
|
||||
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
|
||||
for _, m := range route.matchers {
|
||||
if _, ok := m.(*routeRegexp); ok {
|
||||
if m.Match(req, &RouteMatch{}) {
|
||||
methods, err := route.GetMethods()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allMethods = append(allMethods, methods...)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ","))
|
||||
|
||||
if req.Method == "OPTIONS" {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, req)
|
||||
})
|
||||
}
|
||||
}
|
588
vendor/github.com/gorilla/mux/mux.go
generated
vendored
588
vendor/github.com/gorilla/mux/mux.go
generated
vendored
@ -1,588 +0,0 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mux
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrMethodMismatch is returned when the method in the request does not match
|
||||
// the method defined against the route.
|
||||
ErrMethodMismatch = errors.New("method is not allowed")
|
||||
// ErrNotFound is returned when no route match is found.
|
||||
ErrNotFound = errors.New("no matching route was found")
|
||||
)
|
||||
|
||||
// NewRouter returns a new router instance.
|
||||
func NewRouter() *Router {
|
||||
return &Router{namedRoutes: make(map[string]*Route), KeepContext: false}
|
||||
}
|
||||
|
||||
// Router registers routes to be matched and dispatches a handler.
|
||||
//
|
||||
// It implements the http.Handler interface, so it can be registered to serve
|
||||
// requests:
|
||||
//
|
||||
// var router = mux.NewRouter()
|
||||
//
|
||||
// func main() {
|
||||
// http.Handle("/", router)
|
||||
// }
|
||||
//
|
||||
// Or, for Google App Engine, register it in a init() function:
|
||||
//
|
||||
// func init() {
|
||||
// http.Handle("/", router)
|
||||
// }
|
||||
//
|
||||
// This will send all incoming requests to the router.
|
||||
type Router struct {
|
||||
// Configurable Handler to be used when no route matches.
|
||||
NotFoundHandler http.Handler
|
||||
|
||||
// Configurable Handler to be used when the request method does not match the route.
|
||||
MethodNotAllowedHandler http.Handler
|
||||
|
||||
// Parent route, if this is a subrouter.
|
||||
parent parentRoute
|
||||
// Routes to be matched, in order.
|
||||
routes []*Route
|
||||
// Routes by name for URL building.
|
||||
namedRoutes map[string]*Route
|
||||
// See Router.StrictSlash(). This defines the flag for new routes.
|
||||
strictSlash bool
|
||||
// See Router.SkipClean(). This defines the flag for new routes.
|
||||
skipClean bool
|
||||
// If true, do not clear the request context after handling the request.
|
||||
// This has no effect when go1.7+ is used, since the context is stored
|
||||
// on the request itself.
|
||||
KeepContext bool
|
||||
// see Router.UseEncodedPath(). This defines a flag for all routes.
|
||||
useEncodedPath bool
|
||||
// Slice of middlewares to be called after a match is found
|
||||
middlewares []middleware
|
||||
}
|
||||
|
||||
// Match attempts to match the given request against the router's registered routes.
|
||||
//
|
||||
// If the request matches a route of this router or one of its subrouters the Route,
|
||||
// Handler, and Vars fields of the the match argument are filled and this function
|
||||
// returns true.
|
||||
//
|
||||
// If the request does not match any of this router's or its subrouters' routes
|
||||
// then this function returns false. If available, a reason for the match failure
|
||||
// will be filled in the match argument's MatchErr field. If the match failure type
|
||||
// (eg: not found) has a registered handler, the handler is assigned to the Handler
|
||||
// field of the match argument.
|
||||
func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
|
||||
for _, route := range r.routes {
|
||||
if route.Match(req, match) {
|
||||
// Build middleware chain if no error was found
|
||||
if match.MatchErr == nil {
|
||||
for i := len(r.middlewares) - 1; i >= 0; i-- {
|
||||
match.Handler = r.middlewares[i].Middleware(match.Handler)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if match.MatchErr == ErrMethodMismatch {
|
||||
if r.MethodNotAllowedHandler != nil {
|
||||
match.Handler = r.MethodNotAllowedHandler
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Closest match for a router (includes sub-routers)
|
||||
if r.NotFoundHandler != nil {
|
||||
match.Handler = r.NotFoundHandler
|
||||
match.MatchErr = ErrNotFound
|
||||
return true
|
||||
}
|
||||
|
||||
match.MatchErr = ErrNotFound
|
||||
return false
|
||||
}
|
||||
|
||||
// ServeHTTP dispatches the handler registered in the matched route.
|
||||
//
|
||||
// When there is a match, the route variables can be retrieved calling
|
||||
// mux.Vars(request).
|
||||
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
if !r.skipClean {
|
||||
path := req.URL.Path
|
||||
if r.useEncodedPath {
|
||||
path = req.URL.EscapedPath()
|
||||
}
|
||||
// Clean path to canonical form and redirect.
|
||||
if p := cleanPath(path); p != path {
|
||||
|
||||
// Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query.
|
||||
// This matches with fix in go 1.2 r.c. 4 for same problem. Go Issue:
|
||||
// http://code.google.com/p/go/issues/detail?id=5252
|
||||
url := *req.URL
|
||||
url.Path = p
|
||||
p = url.String()
|
||||
|
||||
w.Header().Set("Location", p)
|
||||
w.WriteHeader(http.StatusMovedPermanently)
|
||||
return
|
||||
}
|
||||
}
|
||||
var match RouteMatch
|
||||
var handler http.Handler
|
||||
if r.Match(req, &match) {
|
||||
handler = match.Handler
|
||||
req = setVars(req, match.Vars)
|
||||
req = setCurrentRoute(req, match.Route)
|
||||
}
|
||||
|
||||
if handler == nil && match.MatchErr == ErrMethodMismatch {
|
||||
handler = methodNotAllowedHandler()
|
||||
}
|
||||
|
||||
if handler == nil {
|
||||
handler = http.NotFoundHandler()
|
||||
}
|
||||
|
||||
if !r.KeepContext {
|
||||
defer contextClear(req)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
// Get returns a route registered with the given name.
|
||||
func (r *Router) Get(name string) *Route {
|
||||
return r.getNamedRoutes()[name]
|
||||
}
|
||||
|
||||
// GetRoute returns a route registered with the given name. This method
|
||||
// was renamed to Get() and remains here for backwards compatibility.
|
||||
func (r *Router) GetRoute(name string) *Route {
|
||||
return r.getNamedRoutes()[name]
|
||||
}
|
||||
|
||||
// StrictSlash defines the trailing slash behavior for new routes. The initial
|
||||
// value is false.
|
||||
//
|
||||
// When true, if the route path is "/path/", accessing "/path" will perform a redirect
|
||||
// to the former and vice versa. In other words, your application will always
|
||||
// see the path as specified in the route.
|
||||
//
|
||||
// When false, if the route path is "/path", accessing "/path/" will not match
|
||||
// this route and vice versa.
|
||||
//
|
||||
// The re-direct is a HTTP 301 (Moved Permanently). Note that when this is set for
|
||||
// routes with a non-idempotent method (e.g. POST, PUT), the subsequent re-directed
|
||||
// request will be made as a GET by most clients. Use middleware or client settings
|
||||
// to modify this behaviour as needed.
|
||||
//
|
||||
// Special case: when a route sets a path prefix using the PathPrefix() method,
|
||||
// strict slash is ignored for that route because the redirect behavior can't
|
||||
// be determined from a prefix alone. However, any subrouters created from that
|
||||
// route inherit the original StrictSlash setting.
|
||||
func (r *Router) StrictSlash(value bool) *Router {
|
||||
r.strictSlash = value
|
||||
return r
|
||||
}
|
||||
|
||||
// SkipClean defines the path cleaning behaviour for new routes. The initial
|
||||
// value is false. Users should be careful about which routes are not cleaned
|
||||
//
|
||||
// When true, if the route path is "/path//to", it will remain with the double
|
||||
// slash. This is helpful if you have a route like: /fetch/http://xkcd.com/534/
|
||||
//
|
||||
// When false, the path will be cleaned, so /fetch/http://xkcd.com/534/ will
|
||||
// become /fetch/http/xkcd.com/534
|
||||
func (r *Router) SkipClean(value bool) *Router {
|
||||
r.skipClean = value
|
||||
return r
|
||||
}
|
||||
|
||||
// UseEncodedPath tells the router to match the encoded original path
|
||||
// to the routes.
|
||||
// For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to".
|
||||
//
|
||||
// If not called, the router will match the unencoded path to the routes.
|
||||
// For eg. "/path/foo%2Fbar/to" will match the path "/path/foo/bar/to"
|
||||
func (r *Router) UseEncodedPath() *Router {
|
||||
r.useEncodedPath = true
|
||||
return r
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// parentRoute
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (r *Router) getBuildScheme() string {
|
||||
if r.parent != nil {
|
||||
return r.parent.getBuildScheme()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getNamedRoutes returns the map where named routes are registered.
|
||||
func (r *Router) getNamedRoutes() map[string]*Route {
|
||||
if r.namedRoutes == nil {
|
||||
if r.parent != nil {
|
||||
r.namedRoutes = r.parent.getNamedRoutes()
|
||||
} else {
|
||||
r.namedRoutes = make(map[string]*Route)
|
||||
}
|
||||
}
|
||||
return r.namedRoutes
|
||||
}
|
||||
|
||||
// getRegexpGroup returns regexp definitions from the parent route, if any.
|
||||
func (r *Router) getRegexpGroup() *routeRegexpGroup {
|
||||
if r.parent != nil {
|
||||
return r.parent.getRegexpGroup()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) buildVars(m map[string]string) map[string]string {
|
||||
if r.parent != nil {
|
||||
m = r.parent.buildVars(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Route factories
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// NewRoute registers an empty route.
|
||||
func (r *Router) NewRoute() *Route {
|
||||
route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath}
|
||||
r.routes = append(r.routes, route)
|
||||
return route
|
||||
}
|
||||
|
||||
// Handle registers a new route with a matcher for the URL path.
|
||||
// See Route.Path() and Route.Handler().
|
||||
func (r *Router) Handle(path string, handler http.Handler) *Route {
|
||||
return r.NewRoute().Path(path).Handler(handler)
|
||||
}
|
||||
|
||||
// HandleFunc registers a new route with a matcher for the URL path.
|
||||
// See Route.Path() and Route.HandlerFunc().
|
||||
func (r *Router) HandleFunc(path string, f func(http.ResponseWriter,
|
||||
*http.Request)) *Route {
|
||||
return r.NewRoute().Path(path).HandlerFunc(f)
|
||||
}
|
||||
|
||||
// Headers registers a new route with a matcher for request header values.
|
||||
// See Route.Headers().
|
||||
func (r *Router) Headers(pairs ...string) *Route {
|
||||
return r.NewRoute().Headers(pairs...)
|
||||
}
|
||||
|
||||
// Host registers a new route with a matcher for the URL host.
|
||||
// See Route.Host().
|
||||
func (r *Router) Host(tpl string) *Route {
|
||||
return r.NewRoute().Host(tpl)
|
||||
}
|
||||
|
||||
// MatcherFunc registers a new route with a custom matcher function.
|
||||
// See Route.MatcherFunc().
|
||||
func (r *Router) MatcherFunc(f MatcherFunc) *Route {
|
||||
return r.NewRoute().MatcherFunc(f)
|
||||
}
|
||||
|
||||
// Methods registers a new route with a matcher for HTTP methods.
|
||||
// See Route.Methods().
|
||||
func (r *Router) Methods(methods ...string) *Route {
|
||||
return r.NewRoute().Methods(methods...)
|
||||
}
|
||||
|
||||
// Path registers a new route with a matcher for the URL path.
|
||||
// See Route.Path().
|
||||
func (r *Router) Path(tpl string) *Route {
|
||||
return r.NewRoute().Path(tpl)
|
||||
}
|
||||
|
||||
// PathPrefix registers a new route with a matcher for the URL path prefix.
|
||||
// See Route.PathPrefix().
|
||||
func (r *Router) PathPrefix(tpl string) *Route {
|
||||
return r.NewRoute().PathPrefix(tpl)
|
||||
}
|
||||
|
||||
// Queries registers a new route with a matcher for URL query values.
|
||||
// See Route.Queries().
|
||||
func (r *Router) Queries(pairs ...string) *Route {
|
||||
return r.NewRoute().Queries(pairs...)
|
||||
}
|
||||
|
||||
// Schemes registers a new route with a matcher for URL schemes.
|
||||
// See Route.Schemes().
|
||||
func (r *Router) Schemes(schemes ...string) *Route {
|
||||
return r.NewRoute().Schemes(schemes...)
|
||||
}
|
||||
|
||||
// BuildVarsFunc registers a new route with a custom function for modifying
|
||||
// route variables before building a URL.
|
||||
func (r *Router) BuildVarsFunc(f BuildVarsFunc) *Route {
|
||||
return r.NewRoute().BuildVarsFunc(f)
|
||||
}
|
||||
|
||||
// Walk walks the router and all its sub-routers, calling walkFn for each route
|
||||
// in the tree. The routes are walked in the order they were added. Sub-routers
|
||||
// are explored depth-first.
|
||||
func (r *Router) Walk(walkFn WalkFunc) error {
|
||||
return r.walk(walkFn, []*Route{})
|
||||
}
|
||||
|
||||
// SkipRouter is used as a return value from WalkFuncs to indicate that the
|
||||
// router that walk is about to descend down to should be skipped.
|
||||
var SkipRouter = errors.New("skip this router")
|
||||
|
||||
// WalkFunc is the type of the function called for each route visited by Walk.
|
||||
// At every invocation, it is given the current route, and the current router,
|
||||
// and a list of ancestor routes that lead to the current route.
|
||||
type WalkFunc func(route *Route, router *Router, ancestors []*Route) error
|
||||
|
||||
func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error {
|
||||
for _, t := range r.routes {
|
||||
err := walkFn(t, r, ancestors)
|
||||
if err == SkipRouter {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, sr := range t.matchers {
|
||||
if h, ok := sr.(*Router); ok {
|
||||
ancestors = append(ancestors, t)
|
||||
err := h.walk(walkFn, ancestors)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ancestors = ancestors[:len(ancestors)-1]
|
||||
}
|
||||
}
|
||||
if h, ok := t.handler.(*Router); ok {
|
||||
ancestors = append(ancestors, t)
|
||||
err := h.walk(walkFn, ancestors)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ancestors = ancestors[:len(ancestors)-1]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Context
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// RouteMatch stores information about a matched route.
|
||||
type RouteMatch struct {
|
||||
Route *Route
|
||||
Handler http.Handler
|
||||
Vars map[string]string
|
||||
|
||||
// MatchErr is set to appropriate matching error
|
||||
// It is set to ErrMethodMismatch if there is a mismatch in
|
||||
// the request method and route method
|
||||
MatchErr error
|
||||
}
|
||||
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
varsKey contextKey = iota
|
||||
routeKey
|
||||
)
|
||||
|
||||
// Vars returns the route variables for the current request, if any.
|
||||
func Vars(r *http.Request) map[string]string {
|
||||
if rv := contextGet(r, varsKey); rv != nil {
|
||||
return rv.(map[string]string)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CurrentRoute returns the matched route for the current request, if any.
|
||||
// This only works when called inside the handler of the matched route
|
||||
// because the matched route is stored in the request context which is cleared
|
||||
// after the handler returns, unless the KeepContext option is set on the
|
||||
// Router.
|
||||
func CurrentRoute(r *http.Request) *Route {
|
||||
if rv := contextGet(r, routeKey); rv != nil {
|
||||
return rv.(*Route)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setVars(r *http.Request, val interface{}) *http.Request {
|
||||
return contextSet(r, varsKey, val)
|
||||
}
|
||||
|
||||
func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
|
||||
return contextSet(r, routeKey, val)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// cleanPath returns the canonical path for p, eliminating . and .. elements.
|
||||
// Borrowed from the net/http package.
|
||||
func cleanPath(p string) string {
|
||||
if p == "" {
|
||||
return "/"
|
||||
}
|
||||
if p[0] != '/' {
|
||||
p = "/" + p
|
||||
}
|
||||
np := path.Clean(p)
|
||||
// path.Clean removes trailing slash except for root;
|
||||
// put the trailing slash back if necessary.
|
||||
if p[len(p)-1] == '/' && np != "/" {
|
||||
np += "/"
|
||||
}
|
||||
|
||||
return np
|
||||
}
|
||||
|
||||
// uniqueVars returns an error if two slices contain duplicated strings.
|
||||
func uniqueVars(s1, s2 []string) error {
|
||||
for _, v1 := range s1 {
|
||||
for _, v2 := range s2 {
|
||||
if v1 == v2 {
|
||||
return fmt.Errorf("mux: duplicated route variable %q", v2)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkPairs returns the count of strings passed in, and an error if
|
||||
// the count is not an even number.
|
||||
func checkPairs(pairs ...string) (int, error) {
|
||||
length := len(pairs)
|
||||
if length%2 != 0 {
|
||||
return length, fmt.Errorf(
|
||||
"mux: number of parameters must be multiple of 2, got %v", pairs)
|
||||
}
|
||||
return length, nil
|
||||
}
|
||||
|
||||
// mapFromPairsToString converts variadic string parameters to a
|
||||
// string to string map.
|
||||
func mapFromPairsToString(pairs ...string) (map[string]string, error) {
|
||||
length, err := checkPairs(pairs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := make(map[string]string, length/2)
|
||||
for i := 0; i < length; i += 2 {
|
||||
m[pairs[i]] = pairs[i+1]
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// mapFromPairsToRegex converts variadic string parameters to a
|
||||
// string to regex map.
|
||||
func mapFromPairsToRegex(pairs ...string) (map[string]*regexp.Regexp, error) {
|
||||
length, err := checkPairs(pairs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := make(map[string]*regexp.Regexp, length/2)
|
||||
for i := 0; i < length; i += 2 {
|
||||
regex, err := regexp.Compile(pairs[i+1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m[pairs[i]] = regex
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// matchInArray returns true if the given string value is in the array.
|
||||
func matchInArray(arr []string, value string) bool {
|
||||
for _, v := range arr {
|
||||
if v == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchMapWithString returns true if the given key/value pairs exist in a given map.
|
||||
func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool {
|
||||
for k, v := range toCheck {
|
||||
// Check if key exists.
|
||||
if canonicalKey {
|
||||
k = http.CanonicalHeaderKey(k)
|
||||
}
|
||||
if values := toMatch[k]; values == nil {
|
||||
return false
|
||||
} else if v != "" {
|
||||
// If value was defined as an empty string we only check that the
|
||||
// key exists. Otherwise we also check for equality.
|
||||
valueExists := false
|
||||
for _, value := range values {
|
||||
if v == value {
|
||||
valueExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valueExists {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// matchMapWithRegex returns true if the given key/value pairs exist in a given map compiled against
|
||||
// the given regex
|
||||
func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]string, canonicalKey bool) bool {
|
||||
for k, v := range toCheck {
|
||||
// Check if key exists.
|
||||
if canonicalKey {
|
||||
k = http.CanonicalHeaderKey(k)
|
||||
}
|
||||
if values := toMatch[k]; values == nil {
|
||||
return false
|
||||
} else if v != nil {
|
||||
// If value was defined as an empty string we only check that the
|
||||
// key exists. Otherwise we also check for equality.
|
||||
valueExists := false
|
||||
for _, value := range values {
|
||||
if v.MatchString(value) {
|
||||
valueExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !valueExists {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// methodNotAllowed replies to the request with an HTTP status code 405.
|
||||
func methodNotAllowed(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
// methodNotAllowedHandler returns a simple request handler
|
||||
// that replies to each request with a status code 405.
|
||||
func methodNotAllowedHandler() http.Handler { return http.HandlerFunc(methodNotAllowed) }
|
332
vendor/github.com/gorilla/mux/regexp.go
generated
vendored
332
vendor/github.com/gorilla/mux/regexp.go
generated
vendored
@ -1,332 +0,0 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type routeRegexpOptions struct {
|
||||
strictSlash bool
|
||||
useEncodedPath bool
|
||||
}
|
||||
|
||||
type regexpType int
|
||||
|
||||
const (
|
||||
regexpTypePath regexpType = 0
|
||||
regexpTypeHost regexpType = 1
|
||||
regexpTypePrefix regexpType = 2
|
||||
regexpTypeQuery regexpType = 3
|
||||
)
|
||||
|
||||
// newRouteRegexp parses a route template and returns a routeRegexp,
|
||||
// used to match a host, a path or a query string.
|
||||
//
|
||||
// It will extract named variables, assemble a regexp to be matched, create
|
||||
// a "reverse" template to build URLs and compile regexps to validate variable
|
||||
// values used in URL building.
|
||||
//
|
||||
// Previously we accepted only Python-like identifiers for variable
|
||||
// names ([a-zA-Z_][a-zA-Z0-9_]*), but currently the only restriction is that
|
||||
// name and pattern can't be empty, and names can't contain a colon.
|
||||
func newRouteRegexp(tpl string, typ regexpType, options routeRegexpOptions) (*routeRegexp, error) {
|
||||
// Check if it is well-formed.
|
||||
idxs, errBraces := braceIndices(tpl)
|
||||
if errBraces != nil {
|
||||
return nil, errBraces
|
||||
}
|
||||
// Backup the original.
|
||||
template := tpl
|
||||
// Now let's parse it.
|
||||
defaultPattern := "[^/]+"
|
||||
if typ == regexpTypeQuery {
|
||||
defaultPattern = ".*"
|
||||
} else if typ == regexpTypeHost {
|
||||
defaultPattern = "[^.]+"
|
||||
}
|
||||
// Only match strict slash if not matching
|
||||
if typ != regexpTypePath {
|
||||
options.strictSlash = false
|
||||
}
|
||||
// Set a flag for strictSlash.
|
||||
endSlash := false
|
||||
if options.strictSlash && strings.HasSuffix(tpl, "/") {
|
||||
tpl = tpl[:len(tpl)-1]
|
||||
endSlash = true
|
||||
}
|
||||
varsN := make([]string, len(idxs)/2)
|
||||
varsR := make([]*regexp.Regexp, len(idxs)/2)
|
||||
pattern := bytes.NewBufferString("")
|
||||
pattern.WriteByte('^')
|
||||
reverse := bytes.NewBufferString("")
|
||||
var end int
|
||||
var err error
|
||||
for i := 0; i < len(idxs); i += 2 {
|
||||
// Set all values we are interested in.
|
||||
raw := tpl[end:idxs[i]]
|
||||
end = idxs[i+1]
|
||||
parts := strings.SplitN(tpl[idxs[i]+1:end-1], ":", 2)
|
||||
name := parts[0]
|
||||
patt := defaultPattern
|
||||
if len(parts) == 2 {
|
||||
patt = parts[1]
|
||||
}
|
||||
// Name or pattern can't be empty.
|
||||
if name == "" || patt == "" {
|
||||
return nil, fmt.Errorf("mux: missing name or pattern in %q",
|
||||
tpl[idxs[i]:end])
|
||||
}
|
||||
// Build the regexp pattern.
|
||||
fmt.Fprintf(pattern, "%s(?P<%s>%s)", regexp.QuoteMeta(raw), varGroupName(i/2), patt)
|
||||
|
||||
// Build the reverse template.
|
||||
fmt.Fprintf(reverse, "%s%%s", raw)
|
||||
|
||||
// Append variable name and compiled pattern.
|
||||
varsN[i/2] = name
|
||||
varsR[i/2], err = regexp.Compile(fmt.Sprintf("^%s$", patt))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Add the remaining.
|
||||
raw := tpl[end:]
|
||||
pattern.WriteString(regexp.QuoteMeta(raw))
|
||||
if options.strictSlash {
|
||||
pattern.WriteString("[/]?")
|
||||
}
|
||||
if typ == regexpTypeQuery {
|
||||
// Add the default pattern if the query value is empty
|
||||
if queryVal := strings.SplitN(template, "=", 2)[1]; queryVal == "" {
|
||||
pattern.WriteString(defaultPattern)
|
||||
}
|
||||
}
|
||||
if typ != regexpTypePrefix {
|
||||
pattern.WriteByte('$')
|
||||
}
|
||||
reverse.WriteString(raw)
|
||||
if endSlash {
|
||||
reverse.WriteByte('/')
|
||||
}
|
||||
// Compile full regexp.
|
||||
reg, errCompile := regexp.Compile(pattern.String())
|
||||
if errCompile != nil {
|
||||
return nil, errCompile
|
||||
}
|
||||
|
||||
// Check for capturing groups which used to work in older versions
|
||||
if reg.NumSubexp() != len(idxs)/2 {
|
||||
panic(fmt.Sprintf("route %s contains capture groups in its regexp. ", template) +
|
||||
"Only non-capturing groups are accepted: e.g. (?:pattern) instead of (pattern)")
|
||||
}
|
||||
|
||||
// Done!
|
||||
return &routeRegexp{
|
||||
template: template,
|
||||
regexpType: typ,
|
||||
options: options,
|
||||
regexp: reg,
|
||||
reverse: reverse.String(),
|
||||
varsN: varsN,
|
||||
varsR: varsR,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// routeRegexp stores a regexp to match a host or path and information to
|
||||
// collect and validate route variables.
|
||||
type routeRegexp struct {
|
||||
// The unmodified template.
|
||||
template string
|
||||
// The type of match
|
||||
regexpType regexpType
|
||||
// Options for matching
|
||||
options routeRegexpOptions
|
||||
// Expanded regexp.
|
||||
regexp *regexp.Regexp
|
||||
// Reverse template.
|
||||
reverse string
|
||||
// Variable names.
|
||||
varsN []string
|
||||
// Variable regexps (validators).
|
||||
varsR []*regexp.Regexp
|
||||
}
|
||||
|
||||
// Match matches the regexp against the URL host or path.
|
||||
func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
|
||||
if r.regexpType != regexpTypeHost {
|
||||
if r.regexpType == regexpTypeQuery {
|
||||
return r.matchQueryString(req)
|
||||
}
|
||||
path := req.URL.Path
|
||||
if r.options.useEncodedPath {
|
||||
path = req.URL.EscapedPath()
|
||||
}
|
||||
return r.regexp.MatchString(path)
|
||||
}
|
||||
|
||||
return r.regexp.MatchString(getHost(req))
|
||||
}
|
||||
|
||||
// url builds a URL part using the given values.
|
||||
func (r *routeRegexp) url(values map[string]string) (string, error) {
|
||||
urlValues := make([]interface{}, len(r.varsN))
|
||||
for k, v := range r.varsN {
|
||||
value, ok := values[v]
|
||||
if !ok {
|
||||
return "", fmt.Errorf("mux: missing route variable %q", v)
|
||||
}
|
||||
if r.regexpType == regexpTypeQuery {
|
||||
value = url.QueryEscape(value)
|
||||
}
|
||||
urlValues[k] = value
|
||||
}
|
||||
rv := fmt.Sprintf(r.reverse, urlValues...)
|
||||
if !r.regexp.MatchString(rv) {
|
||||
// The URL is checked against the full regexp, instead of checking
|
||||
// individual variables. This is faster but to provide a good error
|
||||
// message, we check individual regexps if the URL doesn't match.
|
||||
for k, v := range r.varsN {
|
||||
if !r.varsR[k].MatchString(values[v]) {
|
||||
return "", fmt.Errorf(
|
||||
"mux: variable %q doesn't match, expected %q", values[v],
|
||||
r.varsR[k].String())
|
||||
}
|
||||
}
|
||||
}
|
||||
return rv, nil
|
||||
}
|
||||
|
||||
// getURLQuery returns a single query parameter from a request URL.
|
||||
// For a URL with foo=bar&baz=ding, we return only the relevant key
|
||||
// value pair for the routeRegexp.
|
||||
func (r *routeRegexp) getURLQuery(req *http.Request) string {
|
||||
if r.regexpType != regexpTypeQuery {
|
||||
return ""
|
||||
}
|
||||
templateKey := strings.SplitN(r.template, "=", 2)[0]
|
||||
for key, vals := range req.URL.Query() {
|
||||
if key == templateKey && len(vals) > 0 {
|
||||
return key + "=" + vals[0]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *routeRegexp) matchQueryString(req *http.Request) bool {
|
||||
return r.regexp.MatchString(r.getURLQuery(req))
|
||||
}
|
||||
|
||||
// braceIndices returns the first level curly brace indices from a string.
|
||||
// It returns an error in case of unbalanced braces.
|
||||
func braceIndices(s string) ([]int, error) {
|
||||
var level, idx int
|
||||
var idxs []int
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '{':
|
||||
if level++; level == 1 {
|
||||
idx = i
|
||||
}
|
||||
case '}':
|
||||
if level--; level == 0 {
|
||||
idxs = append(idxs, idx, i+1)
|
||||
} else if level < 0 {
|
||||
return nil, fmt.Errorf("mux: unbalanced braces in %q", s)
|
||||
}
|
||||
}
|
||||
}
|
||||
if level != 0 {
|
||||
return nil, fmt.Errorf("mux: unbalanced braces in %q", s)
|
||||
}
|
||||
return idxs, nil
|
||||
}
|
||||
|
||||
// varGroupName builds a capturing group name for the indexed variable.
|
||||
func varGroupName(idx int) string {
|
||||
return "v" + strconv.Itoa(idx)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// routeRegexpGroup
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// routeRegexpGroup groups the route matchers that carry variables.
|
||||
type routeRegexpGroup struct {
|
||||
host *routeRegexp
|
||||
path *routeRegexp
|
||||
queries []*routeRegexp
|
||||
}
|
||||
|
||||
// setMatch extracts the variables from the URL once a route matches.
|
||||
func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) {
|
||||
// Store host variables.
|
||||
if v.host != nil {
|
||||
host := getHost(req)
|
||||
matches := v.host.regexp.FindStringSubmatchIndex(host)
|
||||
if len(matches) > 0 {
|
||||
extractVars(host, matches, v.host.varsN, m.Vars)
|
||||
}
|
||||
}
|
||||
path := req.URL.Path
|
||||
if r.useEncodedPath {
|
||||
path = req.URL.EscapedPath()
|
||||
}
|
||||
// Store path variables.
|
||||
if v.path != nil {
|
||||
matches := v.path.regexp.FindStringSubmatchIndex(path)
|
||||
if len(matches) > 0 {
|
||||
extractVars(path, matches, v.path.varsN, m.Vars)
|
||||
// Check if we should redirect.
|
||||
if v.path.options.strictSlash {
|
||||
p1 := strings.HasSuffix(path, "/")
|
||||
p2 := strings.HasSuffix(v.path.template, "/")
|
||||
if p1 != p2 {
|
||||
u, _ := url.Parse(req.URL.String())
|
||||
if p1 {
|
||||
u.Path = u.Path[:len(u.Path)-1]
|
||||
} else {
|
||||
u.Path += "/"
|
||||
}
|
||||
m.Handler = http.RedirectHandler(u.String(), 301)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Store query string variables.
|
||||
for _, q := range v.queries {
|
||||
queryURL := q.getURLQuery(req)
|
||||
matches := q.regexp.FindStringSubmatchIndex(queryURL)
|
||||
if len(matches) > 0 {
|
||||
extractVars(queryURL, matches, q.varsN, m.Vars)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getHost tries its best to return the request host.
|
||||
func getHost(r *http.Request) string {
|
||||
if r.URL.IsAbs() {
|
||||
return r.URL.Host
|
||||
}
|
||||
host := r.Host
|
||||
// Slice off any port information.
|
||||
if i := strings.Index(host, ":"); i != -1 {
|
||||
host = host[:i]
|
||||
}
|
||||
return host
|
||||
|
||||
}
|
||||
|
||||
func extractVars(input string, matches []int, names []string, output map[string]string) {
|
||||
for i, name := range names {
|
||||
output[name] = input[matches[2*i+2]:matches[2*i+3]]
|
||||
}
|
||||
}
|
763
vendor/github.com/gorilla/mux/route.go
generated
vendored
763
vendor/github.com/gorilla/mux/route.go
generated
vendored
@ -1,763 +0,0 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mux
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Route stores information to match a request and build URLs.
|
||||
type Route struct {
|
||||
// Parent where the route was registered (a Router).
|
||||
parent parentRoute
|
||||
// Request handler for the route.
|
||||
handler http.Handler
|
||||
// List of matchers.
|
||||
matchers []matcher
|
||||
// Manager for the variables from host and path.
|
||||
regexp *routeRegexpGroup
|
||||
// If true, when the path pattern is "/path/", accessing "/path" will
|
||||
// redirect to the former and vice versa.
|
||||
strictSlash bool
|
||||
// If true, when the path pattern is "/path//to", accessing "/path//to"
|
||||
// will not redirect
|
||||
skipClean bool
|
||||
// If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to"
|
||||
useEncodedPath bool
|
||||
// The scheme used when building URLs.
|
||||
buildScheme string
|
||||
// If true, this route never matches: it is only used to build URLs.
|
||||
buildOnly bool
|
||||
// The name used to build URLs.
|
||||
name string
|
||||
// Error resulted from building a route.
|
||||
err error
|
||||
|
||||
buildVarsFunc BuildVarsFunc
|
||||
}
|
||||
|
||||
// SkipClean reports whether path cleaning is enabled for this route via
|
||||
// Router.SkipClean.
|
||||
func (r *Route) SkipClean() bool {
|
||||
return r.skipClean
|
||||
}
|
||||
|
||||
// Match matches the route against the request.
|
||||
func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
|
||||
if r.buildOnly || r.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var matchErr error
|
||||
|
||||
// Match everything.
|
||||
for _, m := range r.matchers {
|
||||
if matched := m.Match(req, match); !matched {
|
||||
if _, ok := m.(methodMatcher); ok {
|
||||
matchErr = ErrMethodMismatch
|
||||
continue
|
||||
}
|
||||
matchErr = nil
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if matchErr != nil {
|
||||
match.MatchErr = matchErr
|
||||
return false
|
||||
}
|
||||
|
||||
if match.MatchErr == ErrMethodMismatch {
|
||||
// We found a route which matches request method, clear MatchErr
|
||||
match.MatchErr = nil
|
||||
// Then override the mis-matched handler
|
||||
match.Handler = r.handler
|
||||
}
|
||||
|
||||
// Yay, we have a match. Let's collect some info about it.
|
||||
if match.Route == nil {
|
||||
match.Route = r
|
||||
}
|
||||
if match.Handler == nil {
|
||||
match.Handler = r.handler
|
||||
}
|
||||
if match.Vars == nil {
|
||||
match.Vars = make(map[string]string)
|
||||
}
|
||||
|
||||
// Set variables.
|
||||
if r.regexp != nil {
|
||||
r.regexp.setMatch(req, match, r)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Route attributes
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// GetError returns an error resulted from building the route, if any.
|
||||
func (r *Route) GetError() error {
|
||||
return r.err
|
||||
}
|
||||
|
||||
// BuildOnly sets the route to never match: it is only used to build URLs.
|
||||
func (r *Route) BuildOnly() *Route {
|
||||
r.buildOnly = true
|
||||
return r
|
||||
}
|
||||
|
||||
// Handler --------------------------------------------------------------------
|
||||
|
||||
// Handler sets a handler for the route.
|
||||
func (r *Route) Handler(handler http.Handler) *Route {
|
||||
if r.err == nil {
|
||||
r.handler = handler
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// HandlerFunc sets a handler function for the route.
|
||||
func (r *Route) HandlerFunc(f func(http.ResponseWriter, *http.Request)) *Route {
|
||||
return r.Handler(http.HandlerFunc(f))
|
||||
}
|
||||
|
||||
// GetHandler returns the handler for the route, if any.
|
||||
func (r *Route) GetHandler() http.Handler {
|
||||
return r.handler
|
||||
}
|
||||
|
||||
// Name -----------------------------------------------------------------------
|
||||
|
||||
// Name sets the name for the route, used to build URLs.
|
||||
// If the name was registered already it will be overwritten.
|
||||
func (r *Route) Name(name string) *Route {
|
||||
if r.name != "" {
|
||||
r.err = fmt.Errorf("mux: route already has name %q, can't set %q",
|
||||
r.name, name)
|
||||
}
|
||||
if r.err == nil {
|
||||
r.name = name
|
||||
r.getNamedRoutes()[name] = r
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// GetName returns the name for the route, if any.
|
||||
func (r *Route) GetName() string {
|
||||
return r.name
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Matchers
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// matcher types try to match a request.
|
||||
type matcher interface {
|
||||
Match(*http.Request, *RouteMatch) bool
|
||||
}
|
||||
|
||||
// addMatcher adds a matcher to the route.
|
||||
func (r *Route) addMatcher(m matcher) *Route {
|
||||
if r.err == nil {
|
||||
r.matchers = append(r.matchers, m)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// addRegexpMatcher adds a host or path matcher and builder to a route.
|
||||
func (r *Route) addRegexpMatcher(tpl string, typ regexpType) error {
|
||||
if r.err != nil {
|
||||
return r.err
|
||||
}
|
||||
r.regexp = r.getRegexpGroup()
|
||||
if typ == regexpTypePath || typ == regexpTypePrefix {
|
||||
if len(tpl) > 0 && tpl[0] != '/' {
|
||||
return fmt.Errorf("mux: path must start with a slash, got %q", tpl)
|
||||
}
|
||||
if r.regexp.path != nil {
|
||||
tpl = strings.TrimRight(r.regexp.path.template, "/") + tpl
|
||||
}
|
||||
}
|
||||
rr, err := newRouteRegexp(tpl, typ, routeRegexpOptions{
|
||||
strictSlash: r.strictSlash,
|
||||
useEncodedPath: r.useEncodedPath,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, q := range r.regexp.queries {
|
||||
if err = uniqueVars(rr.varsN, q.varsN); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if typ == regexpTypeHost {
|
||||
if r.regexp.path != nil {
|
||||
if err = uniqueVars(rr.varsN, r.regexp.path.varsN); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
r.regexp.host = rr
|
||||
} else {
|
||||
if r.regexp.host != nil {
|
||||
if err = uniqueVars(rr.varsN, r.regexp.host.varsN); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if typ == regexpTypeQuery {
|
||||
r.regexp.queries = append(r.regexp.queries, rr)
|
||||
} else {
|
||||
r.regexp.path = rr
|
||||
}
|
||||
}
|
||||
r.addMatcher(rr)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Headers --------------------------------------------------------------------
|
||||
|
||||
// headerMatcher matches the request against header values.
|
||||
type headerMatcher map[string]string
|
||||
|
||||
func (m headerMatcher) Match(r *http.Request, match *RouteMatch) bool {
|
||||
return matchMapWithString(m, r.Header, true)
|
||||
}
|
||||
|
||||
// Headers adds a matcher for request header values.
|
||||
// It accepts a sequence of key/value pairs to be matched. For example:
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// r.Headers("Content-Type", "application/json",
|
||||
// "X-Requested-With", "XMLHttpRequest")
|
||||
//
|
||||
// The above route will only match if both request header values match.
|
||||
// If the value is an empty string, it will match any value if the key is set.
|
||||
func (r *Route) Headers(pairs ...string) *Route {
|
||||
if r.err == nil {
|
||||
var headers map[string]string
|
||||
headers, r.err = mapFromPairsToString(pairs...)
|
||||
return r.addMatcher(headerMatcher(headers))
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// headerRegexMatcher matches the request against the route given a regex for the header
|
||||
type headerRegexMatcher map[string]*regexp.Regexp
|
||||
|
||||
func (m headerRegexMatcher) Match(r *http.Request, match *RouteMatch) bool {
|
||||
return matchMapWithRegex(m, r.Header, true)
|
||||
}
|
||||
|
||||
// HeadersRegexp accepts a sequence of key/value pairs, where the value has regex
|
||||
// support. For example:
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// r.HeadersRegexp("Content-Type", "application/(text|json)",
|
||||
// "X-Requested-With", "XMLHttpRequest")
|
||||
//
|
||||
// The above route will only match if both the request header matches both regular expressions.
|
||||
// If the value is an empty string, it will match any value if the key is set.
|
||||
// Use the start and end of string anchors (^ and $) to match an exact value.
|
||||
func (r *Route) HeadersRegexp(pairs ...string) *Route {
|
||||
if r.err == nil {
|
||||
var headers map[string]*regexp.Regexp
|
||||
headers, r.err = mapFromPairsToRegex(pairs...)
|
||||
return r.addMatcher(headerRegexMatcher(headers))
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Host -----------------------------------------------------------------------
|
||||
|
||||
// Host adds a matcher for the URL host.
|
||||
// It accepts a template with zero or more URL variables enclosed by {}.
|
||||
// Variables can define an optional regexp pattern to be matched:
|
||||
//
|
||||
// - {name} matches anything until the next dot.
|
||||
//
|
||||
// - {name:pattern} matches the given regexp pattern.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// r.Host("www.example.com")
|
||||
// r.Host("{subdomain}.domain.com")
|
||||
// r.Host("{subdomain:[a-z]+}.domain.com")
|
||||
//
|
||||
// Variable names must be unique in a given route. They can be retrieved
|
||||
// calling mux.Vars(request).
|
||||
func (r *Route) Host(tpl string) *Route {
|
||||
r.err = r.addRegexpMatcher(tpl, regexpTypeHost)
|
||||
return r
|
||||
}
|
||||
|
||||
// MatcherFunc ----------------------------------------------------------------
|
||||
|
||||
// MatcherFunc is the function signature used by custom matchers.
|
||||
type MatcherFunc func(*http.Request, *RouteMatch) bool
|
||||
|
||||
// Match returns the match for a given request.
|
||||
func (m MatcherFunc) Match(r *http.Request, match *RouteMatch) bool {
|
||||
return m(r, match)
|
||||
}
|
||||
|
||||
// MatcherFunc adds a custom function to be used as request matcher.
|
||||
func (r *Route) MatcherFunc(f MatcherFunc) *Route {
|
||||
return r.addMatcher(f)
|
||||
}
|
||||
|
||||
// Methods --------------------------------------------------------------------
|
||||
|
||||
// methodMatcher matches the request against HTTP methods.
|
||||
type methodMatcher []string
|
||||
|
||||
func (m methodMatcher) Match(r *http.Request, match *RouteMatch) bool {
|
||||
return matchInArray(m, r.Method)
|
||||
}
|
||||
|
||||
// Methods adds a matcher for HTTP methods.
|
||||
// It accepts a sequence of one or more methods to be matched, e.g.:
|
||||
// "GET", "POST", "PUT".
|
||||
func (r *Route) Methods(methods ...string) *Route {
|
||||
for k, v := range methods {
|
||||
methods[k] = strings.ToUpper(v)
|
||||
}
|
||||
return r.addMatcher(methodMatcher(methods))
|
||||
}
|
||||
|
||||
// Path -----------------------------------------------------------------------
|
||||
|
||||
// Path adds a matcher for the URL path.
|
||||
// It accepts a template with zero or more URL variables enclosed by {}. The
|
||||
// template must start with a "/".
|
||||
// Variables can define an optional regexp pattern to be matched:
|
||||
//
|
||||
// - {name} matches anything until the next slash.
|
||||
//
|
||||
// - {name:pattern} matches the given regexp pattern.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// r.Path("/products/").Handler(ProductsHandler)
|
||||
// r.Path("/products/{key}").Handler(ProductsHandler)
|
||||
// r.Path("/articles/{category}/{id:[0-9]+}").
|
||||
// Handler(ArticleHandler)
|
||||
//
|
||||
// Variable names must be unique in a given route. They can be retrieved
|
||||
// calling mux.Vars(request).
|
||||
func (r *Route) Path(tpl string) *Route {
|
||||
r.err = r.addRegexpMatcher(tpl, regexpTypePath)
|
||||
return r
|
||||
}
|
||||
|
||||
// PathPrefix -----------------------------------------------------------------
|
||||
|
||||
// PathPrefix adds a matcher for the URL path prefix. This matches if the given
|
||||
// template is a prefix of the full URL path. See Route.Path() for details on
|
||||
// the tpl argument.
|
||||
//
|
||||
// Note that it does not treat slashes specially ("/foobar/" will be matched by
|
||||
// the prefix "/foo") so you may want to use a trailing slash here.
|
||||
//
|
||||
// Also note that the setting of Router.StrictSlash() has no effect on routes
|
||||
// with a PathPrefix matcher.
|
||||
func (r *Route) PathPrefix(tpl string) *Route {
|
||||
r.err = r.addRegexpMatcher(tpl, regexpTypePrefix)
|
||||
return r
|
||||
}
|
||||
|
||||
// Query ----------------------------------------------------------------------
|
||||
|
||||
// Queries adds a matcher for URL query values.
|
||||
// It accepts a sequence of key/value pairs. Values may define variables.
|
||||
// For example:
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// r.Queries("foo", "bar", "id", "{id:[0-9]+}")
|
||||
//
|
||||
// The above route will only match if the URL contains the defined queries
|
||||
// values, e.g.: ?foo=bar&id=42.
|
||||
//
|
||||
// It the value is an empty string, it will match any value if the key is set.
|
||||
//
|
||||
// Variables can define an optional regexp pattern to be matched:
|
||||
//
|
||||
// - {name} matches anything until the next slash.
|
||||
//
|
||||
// - {name:pattern} matches the given regexp pattern.
|
||||
func (r *Route) Queries(pairs ...string) *Route {
|
||||
length := len(pairs)
|
||||
if length%2 != 0 {
|
||||
r.err = fmt.Errorf(
|
||||
"mux: number of parameters must be multiple of 2, got %v", pairs)
|
||||
return nil
|
||||
}
|
||||
for i := 0; i < length; i += 2 {
|
||||
if r.err = r.addRegexpMatcher(pairs[i]+"="+pairs[i+1], regexpTypeQuery); r.err != nil {
|
||||
return r
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// Schemes --------------------------------------------------------------------
|
||||
|
||||
// schemeMatcher matches the request against URL schemes.
|
||||
type schemeMatcher []string
|
||||
|
||||
func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool {
|
||||
return matchInArray(m, r.URL.Scheme)
|
||||
}
|
||||
|
||||
// Schemes adds a matcher for URL schemes.
|
||||
// It accepts a sequence of schemes to be matched, e.g.: "http", "https".
|
||||
func (r *Route) Schemes(schemes ...string) *Route {
|
||||
for k, v := range schemes {
|
||||
schemes[k] = strings.ToLower(v)
|
||||
}
|
||||
if r.buildScheme == "" && len(schemes) > 0 {
|
||||
r.buildScheme = schemes[0]
|
||||
}
|
||||
return r.addMatcher(schemeMatcher(schemes))
|
||||
}
|
||||
|
||||
// BuildVarsFunc --------------------------------------------------------------
|
||||
|
||||
// BuildVarsFunc is the function signature used by custom build variable
|
||||
// functions (which can modify route variables before a route's URL is built).
|
||||
type BuildVarsFunc func(map[string]string) map[string]string
|
||||
|
||||
// BuildVarsFunc adds a custom function to be used to modify build variables
|
||||
// before a route's URL is built.
|
||||
func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route {
|
||||
r.buildVarsFunc = f
|
||||
return r
|
||||
}
|
||||
|
||||
// Subrouter ------------------------------------------------------------------
|
||||
|
||||
// Subrouter creates a subrouter for the route.
|
||||
//
|
||||
// It will test the inner routes only if the parent route matched. For example:
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// s := r.Host("www.example.com").Subrouter()
|
||||
// s.HandleFunc("/products/", ProductsHandler)
|
||||
// s.HandleFunc("/products/{key}", ProductHandler)
|
||||
// s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler)
|
||||
//
|
||||
// Here, the routes registered in the subrouter won't be tested if the host
|
||||
// doesn't match.
|
||||
func (r *Route) Subrouter() *Router {
|
||||
router := &Router{parent: r, strictSlash: r.strictSlash}
|
||||
r.addMatcher(router)
|
||||
return router
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// URL building
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// URL builds a URL for the route.
|
||||
//
|
||||
// It accepts a sequence of key/value pairs for the route variables. For
|
||||
// example, given this route:
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
|
||||
// Name("article")
|
||||
//
|
||||
// ...a URL for it can be built using:
|
||||
//
|
||||
// url, err := r.Get("article").URL("category", "technology", "id", "42")
|
||||
//
|
||||
// ...which will return an url.URL with the following path:
|
||||
//
|
||||
// "/articles/technology/42"
|
||||
//
|
||||
// This also works for host variables:
|
||||
//
|
||||
// r := mux.NewRouter()
|
||||
// r.Host("{subdomain}.domain.com").
|
||||
// HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
|
||||
// Name("article")
|
||||
//
|
||||
// // url.String() will be "http://news.domain.com/articles/technology/42"
|
||||
// url, err := r.Get("article").URL("subdomain", "news",
|
||||
// "category", "technology",
|
||||
// "id", "42")
|
||||
//
|
||||
// All variables defined in the route are required, and their values must
|
||||
// conform to the corresponding patterns.
|
||||
func (r *Route) URL(pairs ...string) (*url.URL, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
if r.regexp == nil {
|
||||
return nil, errors.New("mux: route doesn't have a host or path")
|
||||
}
|
||||
values, err := r.prepareVars(pairs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var scheme, host, path string
|
||||
queries := make([]string, 0, len(r.regexp.queries))
|
||||
if r.regexp.host != nil {
|
||||
if host, err = r.regexp.host.url(values); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scheme = "http"
|
||||
if s := r.getBuildScheme(); s != "" {
|
||||
scheme = s
|
||||
}
|
||||
}
|
||||
if r.regexp.path != nil {
|
||||
if path, err = r.regexp.path.url(values); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for _, q := range r.regexp.queries {
|
||||
var query string
|
||||
if query, err = q.url(values); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queries = append(queries, query)
|
||||
}
|
||||
return &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
Path: path,
|
||||
RawQuery: strings.Join(queries, "&"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// URLHost builds the host part of the URL for a route. See Route.URL().
|
||||
//
|
||||
// The route must have a host defined.
|
||||
func (r *Route) URLHost(pairs ...string) (*url.URL, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.host == nil {
|
||||
return nil, errors.New("mux: route doesn't have a host")
|
||||
}
|
||||
values, err := r.prepareVars(pairs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
host, err := r.regexp.host.url(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
u := &url.URL{
|
||||
Scheme: "http",
|
||||
Host: host,
|
||||
}
|
||||
if s := r.getBuildScheme(); s != "" {
|
||||
u.Scheme = s
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// URLPath builds the path part of the URL for a route. See Route.URL().
|
||||
//
|
||||
// The route must have a path defined.
|
||||
func (r *Route) URLPath(pairs ...string) (*url.URL, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.path == nil {
|
||||
return nil, errors.New("mux: route doesn't have a path")
|
||||
}
|
||||
values, err := r.prepareVars(pairs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
path, err := r.regexp.path.url(values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &url.URL{
|
||||
Path: path,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPathTemplate returns the template used to build the
|
||||
// route match.
|
||||
// This is useful for building simple REST API documentation and for instrumentation
|
||||
// against third-party services.
|
||||
// An error will be returned if the route does not define a path.
|
||||
func (r *Route) GetPathTemplate() (string, error) {
|
||||
if r.err != nil {
|
||||
return "", r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.path == nil {
|
||||
return "", errors.New("mux: route doesn't have a path")
|
||||
}
|
||||
return r.regexp.path.template, nil
|
||||
}
|
||||
|
||||
// GetPathRegexp returns the expanded regular expression used to match route path.
|
||||
// This is useful for building simple REST API documentation and for instrumentation
|
||||
// against third-party services.
|
||||
// An error will be returned if the route does not define a path.
|
||||
func (r *Route) GetPathRegexp() (string, error) {
|
||||
if r.err != nil {
|
||||
return "", r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.path == nil {
|
||||
return "", errors.New("mux: route does not have a path")
|
||||
}
|
||||
return r.regexp.path.regexp.String(), nil
|
||||
}
|
||||
|
||||
// GetQueriesRegexp returns the expanded regular expressions used to match the
|
||||
// route queries.
|
||||
// This is useful for building simple REST API documentation and for instrumentation
|
||||
// against third-party services.
|
||||
// An error will be returned if the route does not have queries.
|
||||
func (r *Route) GetQueriesRegexp() ([]string, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.queries == nil {
|
||||
return nil, errors.New("mux: route doesn't have queries")
|
||||
}
|
||||
var queries []string
|
||||
for _, query := range r.regexp.queries {
|
||||
queries = append(queries, query.regexp.String())
|
||||
}
|
||||
return queries, nil
|
||||
}
|
||||
|
||||
// GetQueriesTemplates returns the templates used to build the
|
||||
// query matching.
|
||||
// This is useful for building simple REST API documentation and for instrumentation
|
||||
// against third-party services.
|
||||
// An error will be returned if the route does not define queries.
|
||||
func (r *Route) GetQueriesTemplates() ([]string, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.queries == nil {
|
||||
return nil, errors.New("mux: route doesn't have queries")
|
||||
}
|
||||
var queries []string
|
||||
for _, query := range r.regexp.queries {
|
||||
queries = append(queries, query.template)
|
||||
}
|
||||
return queries, nil
|
||||
}
|
||||
|
||||
// GetMethods returns the methods the route matches against
|
||||
// This is useful for building simple REST API documentation and for instrumentation
|
||||
// against third-party services.
|
||||
// An error will be returned if route does not have methods.
|
||||
func (r *Route) GetMethods() ([]string, error) {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
for _, m := range r.matchers {
|
||||
if methods, ok := m.(methodMatcher); ok {
|
||||
return []string(methods), nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("mux: route doesn't have methods")
|
||||
}
|
||||
|
||||
// GetHostTemplate returns the template used to build the
|
||||
// route match.
|
||||
// This is useful for building simple REST API documentation and for instrumentation
|
||||
// against third-party services.
|
||||
// An error will be returned if the route does not define a host.
|
||||
func (r *Route) GetHostTemplate() (string, error) {
|
||||
if r.err != nil {
|
||||
return "", r.err
|
||||
}
|
||||
if r.regexp == nil || r.regexp.host == nil {
|
||||
return "", errors.New("mux: route doesn't have a host")
|
||||
}
|
||||
return r.regexp.host.template, nil
|
||||
}
|
||||
|
||||
// prepareVars converts the route variable pairs into a map. If the route has a
|
||||
// BuildVarsFunc, it is invoked.
|
||||
func (r *Route) prepareVars(pairs ...string) (map[string]string, error) {
|
||||
m, err := mapFromPairsToString(pairs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.buildVars(m), nil
|
||||
}
|
||||
|
||||
func (r *Route) buildVars(m map[string]string) map[string]string {
|
||||
if r.parent != nil {
|
||||
m = r.parent.buildVars(m)
|
||||
}
|
||||
if r.buildVarsFunc != nil {
|
||||
m = r.buildVarsFunc(m)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// parentRoute
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// parentRoute allows routes to know about parent host and path definitions.
|
||||
type parentRoute interface {
|
||||
getBuildScheme() string
|
||||
getNamedRoutes() map[string]*Route
|
||||
getRegexpGroup() *routeRegexpGroup
|
||||
buildVars(map[string]string) map[string]string
|
||||
}
|
||||
|
||||
func (r *Route) getBuildScheme() string {
|
||||
if r.buildScheme != "" {
|
||||
return r.buildScheme
|
||||
}
|
||||
if r.parent != nil {
|
||||
return r.parent.getBuildScheme()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getNamedRoutes returns the map where named routes are registered.
|
||||
func (r *Route) getNamedRoutes() map[string]*Route {
|
||||
if r.parent == nil {
|
||||
// During tests router is not always set.
|
||||
r.parent = NewRouter()
|
||||
}
|
||||
return r.parent.getNamedRoutes()
|
||||
}
|
||||
|
||||
// getRegexpGroup returns regexp definitions from this route.
|
||||
func (r *Route) getRegexpGroup() *routeRegexpGroup {
|
||||
if r.regexp == nil {
|
||||
if r.parent == nil {
|
||||
// During tests router is not always set.
|
||||
r.parent = NewRouter()
|
||||
}
|
||||
regexp := r.parent.getRegexpGroup()
|
||||
if regexp == nil {
|
||||
r.regexp = new(routeRegexpGroup)
|
||||
} else {
|
||||
// Copy.
|
||||
r.regexp = &routeRegexpGroup{
|
||||
host: regexp.host,
|
||||
path: regexp.path,
|
||||
queries: regexp.queries,
|
||||
}
|
||||
}
|
||||
}
|
||||
return r.regexp
|
||||
}
|
19
vendor/github.com/gorilla/mux/test_helpers.go
generated
vendored
19
vendor/github.com/gorilla/mux/test_helpers.go
generated
vendored
@ -1,19 +0,0 @@
|
||||
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package mux
|
||||
|
||||
import "net/http"
|
||||
|
||||
// SetURLVars sets the URL variables for the given request, to be accessed via
|
||||
// mux.Vars for testing route behaviour. Arguments are not modified, a shallow
|
||||
// copy is returned.
|
||||
//
|
||||
// This API should only be used for testing purposes; it provides a way to
|
||||
// inject variables into the request context. Alternatively, URL variables
|
||||
// can be set by making a route that captures the required variables,
|
||||
// starting a server and sending the request to that server.
|
||||
func SetURLVars(r *http.Request, val map[string]string) *http.Request {
|
||||
return setVars(r, val)
|
||||
}
|
14
vendor/github.com/mattn/go-sqlite3/.gitignore
generated
vendored
14
vendor/github.com/mattn/go-sqlite3/.gitignore
generated
vendored
@ -1,14 +0,0 @@
|
||||
*.db
|
||||
*.exe
|
||||
*.dll
|
||||
*.o
|
||||
|
||||
# VSCode
|
||||
.vscode
|
||||
|
||||
# Exclude from upgrade
|
||||
upgrade/*.c
|
||||
upgrade/*.h
|
||||
|
||||
# Exclude upgrade binary
|
||||
upgrade/upgrade
|
41
vendor/github.com/mattn/go-sqlite3/.travis.yml
generated
vendored
41
vendor/github.com/mattn/go-sqlite3/.travis.yml
generated
vendored
@ -1,41 +0,0 @@
|
||||
language: go
|
||||
|
||||
os:
|
||||
- linux
|
||||
- osx
|
||||
|
||||
addons:
|
||||
apt:
|
||||
update: true
|
||||
|
||||
env:
|
||||
matrix:
|
||||
- GOTAGS=
|
||||
- GOTAGS=libsqlite3
|
||||
- GOTAGS="sqlite_allow_uri_authority sqlite_app_armor sqlite_foreign_keys sqlite_fts5 sqlite_icu sqlite_introspect sqlite_json sqlite_secure_delete sqlite_see sqlite_stat4 sqlite_trace sqlite_userauth sqlite_vacuum_incr sqlite_vtable"
|
||||
- GOTAGS=sqlite_vacuum_full
|
||||
|
||||
go:
|
||||
- 1.9.x
|
||||
- 1.10.x
|
||||
|
||||
before_install:
|
||||
- |
|
||||
if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then
|
||||
brew update
|
||||
brew upgrade icu4c
|
||||
fi
|
||||
- |
|
||||
go get github.com/smartystreets/goconvey
|
||||
if [[ "${GOOS}" != "windows" ]]; then
|
||||
go get github.com/mattn/goveralls
|
||||
go get golang.org/x/tools/cmd/cover
|
||||
fi
|
||||
|
||||
script:
|
||||
- GOOS=$(go env GOOS) GOARCH=$(go env GOARCH) go build -v -tags "${GOTAGS}" .
|
||||
- |
|
||||
if [[ "${GOOS}" != "windows" ]]; then
|
||||
$HOME/gopath/bin/goveralls -repotoken 3qJVUE0iQwqnCbmNcDsjYu1nh4J4KIFXx
|
||||
go test -race -v . -tags "${GOTAGS}"
|
||||
fi
|
21
vendor/github.com/mattn/go-sqlite3/LICENSE
generated
vendored
21
vendor/github.com/mattn/go-sqlite3/LICENSE
generated
vendored
@ -1,21 +0,0 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Yasuhiro Matsumoto
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
518
vendor/github.com/mattn/go-sqlite3/README.md
generated
vendored
518
vendor/github.com/mattn/go-sqlite3/README.md
generated
vendored
@ -1,518 +0,0 @@
|
||||
go-sqlite3
|
||||
==========
|
||||
|
||||
[![GoDoc Reference](https://godoc.org/github.com/mattn/go-sqlite3?status.svg)](http://godoc.org/github.com/mattn/go-sqlite3)
|
||||
[![Build Status](https://travis-ci.org/mattn/go-sqlite3.svg?branch=master)](https://travis-ci.org/mattn/go-sqlite3)
|
||||
[![Coverage Status](https://coveralls.io/repos/mattn/go-sqlite3/badge.svg?branch=master)](https://coveralls.io/r/mattn/go-sqlite3?branch=master)
|
||||
[![Go Report Card](https://goreportcard.com/badge/github.com/mattn/go-sqlite3)](https://goreportcard.com/report/github.com/mattn/go-sqlite3)
|
||||
|
||||
# Description
|
||||
|
||||
sqlite3 driver conforming to the built-in database/sql interface
|
||||
|
||||
Supported Golang version:
|
||||
- 1.9.x
|
||||
- 1.10.x
|
||||
|
||||
[This package follows the official Golang Release Policy.](https://golang.org/doc/devel/release.html#policy)
|
||||
|
||||
### Overview
|
||||
|
||||
- [Installation](#installation)
|
||||
- [API Reference](#api-reference)
|
||||
- [Connection String](#connection-string)
|
||||
- [Features](#features)
|
||||
- [Compilation](#compilation)
|
||||
- [Android](#android)
|
||||
- [ARM](#arm)
|
||||
- [Cross Compile](#cross-compile)
|
||||
- [Google Cloud Platform](#google-cloud-platform)
|
||||
- [Linux](#linux)
|
||||
- [Alpine](#alpine)
|
||||
- [Fedora](#fedora)
|
||||
- [Ubuntu](#ubuntu)
|
||||
- [Mac OSX](#mac-osx)
|
||||
- [Windows](#windows)
|
||||
- [Errors](#errors)
|
||||
- [User Authentication](#user-authentication)
|
||||
- [Compile](#compile)
|
||||
- [Usage](#usage)
|
||||
- [Extensions](#extensions)
|
||||
- [Spatialite](#spatialite)
|
||||
- [FAQ](#faq)
|
||||
- [License](#license)
|
||||
|
||||
# Installation
|
||||
|
||||
This package can be installed with the go get command:
|
||||
|
||||
go get github.com/mattn/go-sqlite3
|
||||
|
||||
_go-sqlite3_ is *cgo* package.
|
||||
If you want to build your app using go-sqlite3, you need gcc.
|
||||
However, after you have built and installed _go-sqlite3_ with `go install github.com/mattn/go-sqlite3` (which requires gcc), you can build your app without relying on gcc in future.
|
||||
|
||||
***Important: because this is a `CGO` enabled package you are required to set the environment variable `CGO_ENABLED=1` and have a `gcc` compile present within your path.***
|
||||
|
||||
# API Reference
|
||||
|
||||
API documentation can be found here: http://godoc.org/github.com/mattn/go-sqlite3
|
||||
|
||||
Examples can be found under the [examples](./_example) directory
|
||||
|
||||
# Connection String
|
||||
|
||||
When creating a new SQLite database or connection to an existing one, with the file name additional options can be given.
|
||||
This is also known as a DSN string. (Data Source Name).
|
||||
|
||||
Options are append after the filename of the SQLite database.
|
||||
The database filename and options are seperated by an `?` (Question Mark).
|
||||
|
||||
This also applies when using an in-memory database instead of a file.
|
||||
|
||||
Options can be given using the following format: `KEYWORD=VALUE` and multiple options can be combined with the `&` ampersand.
|
||||
|
||||
This library supports dsn options of SQLite itself and provides additional options.
|
||||
|
||||
Boolean values can be one of:
|
||||
* `0` `no` `false` `off`
|
||||
* `1` `yes` `true` `on`
|
||||
|
||||
| Name | Key | Value(s) | Description |
|
||||
|------|-----|----------|-------------|
|
||||
| UA - Create | `_auth` | - | Create User Authentication, for more information see [User Authentication](#user-authentication) |
|
||||
| UA - Username | `_auth_user` | `string` | Username for User Authentication, for more information see [User Authentication](#user-authentication) |
|
||||
| UA - Password | `_auth_pass` | `string` | Password for User Authentication, for more information see [User Authentication](#user-authentication) |
|
||||
| UA - Crypt | `_auth_crypt` | <ul><li>SHA1</li><li>SSHA1</li><li>SHA256</li><li>SSHA256</li><li>SHA384</li><li>SSHA384</li><li>SHA512</li><li>SSHA512</li></ul> | Password encoder to use for User Authentication, for more information see [User Authentication](#user-authentication) |
|
||||
| UA - Salt | `_auth_salt` | `string` | Salt to use if the configure password encoder requires a salt, for User Authentication, for more information see [User Authentication](#user-authentication) |
|
||||
| Auto Vacuum | `_auto_vacuum` \| `_vacuum` | <ul><li>`0` \| `none`</li><li>`1` \| `full`</li><li>`2` \| `incremental`</li></ul> | For more information see [PRAGMA auto_vacuum](https://www.sqlite.org/pragma.html#pragma_auto_vacuum) |
|
||||
| Busy Timeout | `_busy_timeout` \| `_timeout` | `int` | Specify value for sqlite3_busy_timeout. For more information see [PRAGMA busy_timeout](https://www.sqlite.org/pragma.html#pragma_busy_timeout) |
|
||||
| Case Sensitive LIKE | `_case_sensitive_like` \| `_cslike` | `boolean` | For more information see [PRAGMA case_sensitive_like](https://www.sqlite.org/pragma.html#pragma_case_sensitive_like) |
|
||||
| Defer Foreign Keys | `_defer_foreign_keys` \| `_defer_fk` | `boolean` | For more information see [PRAGMA defer_foreign_keys](https://www.sqlite.org/pragma.html#pragma_defer_foreign_keys) |
|
||||
| Foreign Keys | `_foreign_keys` \| `_fk` | `boolean` | For more information see [PRAGMA foreign_keys](https://www.sqlite.org/pragma.html#pragma_foreign_keys) |
|
||||
| Ignore CHECK Constraints | `_ignore_check_constraints` | `boolean` | For more information see [PRAGMA ignore_check_constraints](https://www.sqlite.org/pragma.html#pragma_ignore_check_constraints) |
|
||||
| Immutable | `immutable` | `boolean` | For more information see [Immutable](https://www.sqlite.org/c3ref/open.html) |
|
||||
| Journal Mode | `_journal_mode` \| `_journal` | <ul><li>DELETE</li><li>TRUNCATE</li><li>PERSIST</li><li>MEMORY</li><li>WAL</li><li>OFF</li></ul> | For more information see [PRAGMA journal_mode](https://www.sqlite.org/pragma.html#pragma_journal_mode) |
|
||||
| Locking Mode | `_locking_mode` \| `_locking` | <ul><li>NORMAL</li><li>EXCLUSIVE</li></ul> | For more information see [PRAGMA locking_mode](https://www.sqlite.org/pragma.html#pragma_locking_mode) |
|
||||
| Mode | `mode` | <ul><li>ro</li><li>rw</li><li>rwc</li><li>memory</li></ul> | Access Mode of the database. For more information see [SQLite Open](https://www.sqlite.org/c3ref/open.html) |
|
||||
| Mutex Locking | `_mutex` | <ul><li>no</li><li>full</li></ul> | Specify mutex mode. |
|
||||
| Query Only | `_query_only` | `boolean` | For more information see [PRAGMA query_only](https://www.sqlite.org/pragma.html#pragma_query_only) |
|
||||
| Recursive Triggers | `_recursive_triggers` \| `_rt` | `boolean` | For more information see [PRAGMA recursive_triggers](https://www.sqlite.org/pragma.html#pragma_recursive_triggers) |
|
||||
| Secure Delete | `_secure_delete` | `boolean` \| `FAST` | For more information see [PRAGMA secure_delete](https://www.sqlite.org/pragma.html#pragma_secure_delete) |
|
||||
| Shared-Cache Mode | `cache` | <ul><li>shared</li><li>private</li></ul> | Set cache mode for more information see [sqlite.org](https://www.sqlite.org/sharedcache.html) |
|
||||
| Synchronous | `_synchronous` \| `_sync` | <ul><li>0 \| OFF</li><li>1 \| NORMAL</li><li>2 \| FULL</li><li>3 \| EXTRA</li></ul> | For more information see [PRAGMA synchronous](https://www.sqlite.org/pragma.html#pragma_synchronous) |
|
||||
| Time Zone Location | `_loc` | auto | Specify location of time format. |
|
||||
| Transaction Lock | `_txlock` | <ul><li>immediate</li><li>deferred</li><li>exclusive</li></ul> | Specify locking behavior for transactions. |
|
||||
| Writable Schema | `_writable_schema` | `Boolean` | When this pragma is on, the SQLITE_MASTER tables in which database can be changed using ordinary UPDATE, INSERT, and DELETE statements. Warning: misuse of this pragma can easily result in a corrupt database file. |
|
||||
|
||||
## DSN Examples
|
||||
|
||||
```
|
||||
file:test.db?cache=shared&mode=memory
|
||||
```
|
||||
|
||||
# Features
|
||||
|
||||
This package allows additional configuration of features available within SQLite3 to be enabled or disabled by golang build constraints also known as build `tags`.
|
||||
|
||||
[Click here for more information about build tags / constraints.](https://golang.org/pkg/go/build/#hdr-Build_Constraints)
|
||||
|
||||
### Usage
|
||||
|
||||
If you wish to build this library with additional extensions / features.
|
||||
Use the following command.
|
||||
|
||||
```bash
|
||||
go build --tags "<FEATURE>"
|
||||
```
|
||||
|
||||
For available features see the extension list.
|
||||
When using multiple build tags, all the different tags should be space delimted.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
go build --tags "icu json1 fts5 secure_delete"
|
||||
```
|
||||
|
||||
### Feature / Extension List
|
||||
|
||||
| Extension | Build Tag | Description |
|
||||
|-----------|-----------|-------------|
|
||||
| Additional Statistics | sqlite_stat4 | This option adds additional logic to the ANALYZE command and to the query planner that can help SQLite to chose a better query plan under certain situations. The ANALYZE command is enhanced to collect histogram data from all columns of every index and store that data in the sqlite_stat4 table.<br><br>The query planner will then use the histogram data to help it make better index choices. The downside of this compile-time option is that it violates the query planner stability guarantee making it more difficult to ensure consistent performance in mass-produced applications.<br><br>SQLITE_ENABLE_STAT4 is an enhancement of SQLITE_ENABLE_STAT3. STAT3 only recorded histogram data for the left-most column of each index whereas the STAT4 enhancement records histogram data from all columns of each index.<br><br>The SQLITE_ENABLE_STAT3 compile-time option is a no-op and is ignored if the SQLITE_ENABLE_STAT4 compile-time option is used |
|
||||
| Allow URI Authority | sqlite_allow_uri_authority | URI filenames normally throws an error if the authority section is not either empty or "localhost".<br><br>However, if SQLite is compiled with the SQLITE_ALLOW_URI_AUTHORITY compile-time option, then the URI is converted into a Uniform Naming Convention (UNC) filename and passed down to the underlying operating system that way |
|
||||
| App Armor | sqlite_app_armor | When defined, this C-preprocessor macro activates extra code that attempts to detect misuse of the SQLite API, such as passing in NULL pointers to required parameters or using objects after they have been destroyed. <br><br>App Armor is not available under `Windows`. |
|
||||
| Disable Load Extensions | sqlite_omit_load_extension | Loading of external extensions is enabled by default.<br><br>To disable extension loading add the build tag `sqlite_omit_load_extension`. |
|
||||
| Foreign Keys | sqlite_foreign_keys | This macro determines whether enforcement of foreign key constraints is enabled or disabled by default for new database connections.<br><br>Each database connection can always turn enforcement of foreign key constraints on and off and run-time using the foreign_keys pragma.<br><br>Enforcement of foreign key constraints is normally off by default, but if this compile-time parameter is set to 1, enforcement of foreign key constraints will be on by default |
|
||||
| Full Auto Vacuum | sqlite_vacuum_full | Set the default auto vacuum to full |
|
||||
| Incremental Auto Vacuum | sqlite_vacuum_incr | Set the default auto vacuum to incremental |
|
||||
| Full Text Search Engine | sqlite_fts5 | When this option is defined in the amalgamation, versions 5 of the full-text search engine (fts5) is added to the build automatically |
|
||||
| International Components for Unicode | sqlite_icu | This option causes the International Components for Unicode or "ICU" extension to SQLite to be added to the build |
|
||||
| Introspect PRAGMAS | sqlite_introspect | This option adds some extra PRAGMA statements. <ul><li>PRAGMA function_list</li><li>PRAGMA module_list</li><li>PRAGMA pragma_list</li></ul> |
|
||||
| JSON SQL Functions | sqlite_json | When this option is defined in the amalgamation, the JSON SQL functions are added to the build automatically |
|
||||
| Secure Delete | sqlite_secure_delete | This compile-time option changes the default setting of the secure_delete pragma.<br><br>When this option is not used, secure_delete defaults to off. When this option is present, secure_delete defaults to on.<br><br>The secure_delete setting causes deleted content to be overwritten with zeros. There is a small performance penalty since additional I/O must occur.<br><br>On the other hand, secure_delete can prevent fragments of sensitive information from lingering in unused parts of the database file after it has been deleted. See the documentation on the secure_delete pragma for additional information |
|
||||
| Secure Delete (FAST) | sqlite_secure_delete_fast | For more information see [PRAGMA secure_delete](https://www.sqlite.org/pragma.html#pragma_secure_delete) |
|
||||
| Tracing / Debug | sqlite_trace | Activate trace functions |
|
||||
| User Authentication | sqlite_userauth | SQLite User Authentication see [User Authentication](#user-authentication) for more information. |
|
||||
|
||||
# Compilation
|
||||
|
||||
This package requires `CGO_ENABLED=1` ennvironment variable if not set by default, and the presence of the `gcc` compiler.
|
||||
|
||||
If you need to add additional CFLAGS or LDFLAGS to the build command, and do not want to modify this package. Then this can be achieved by using the `CGO_CFLAGS` and `CGO_LDFLAGS` environment variables.
|
||||
|
||||
## Android
|
||||
|
||||
This package can be compiled for android.
|
||||
Compile with:
|
||||
|
||||
```bash
|
||||
go build --tags "android"
|
||||
```
|
||||
|
||||
For more information see [#201](https://github.com/mattn/go-sqlite3/issues/201)
|
||||
|
||||
# ARM
|
||||
|
||||
To compile for `ARM` use the following environment.
|
||||
|
||||
```bash
|
||||
env CC=arm-linux-gnueabihf-gcc CXX=arm-linux-gnueabihf-g++ \
|
||||
CGO_ENABLED=1 GOOS=linux GOARCH=arm GOARM=7 \
|
||||
go build -v
|
||||
```
|
||||
|
||||
Additional information:
|
||||
- [#242](https://github.com/mattn/go-sqlite3/issues/242)
|
||||
- [#504](https://github.com/mattn/go-sqlite3/issues/504)
|
||||
|
||||
# Cross Compile
|
||||
|
||||
This library can be cross-compiled.
|
||||
|
||||
In some cases you are required to the `CC` environment variable with the cross compiler.
|
||||
|
||||
Additional information:
|
||||
- [#491](https://github.com/mattn/go-sqlite3/issues/491)
|
||||
- [#560](https://github.com/mattn/go-sqlite3/issues/560)
|
||||
|
||||
# Google Cloud Platform
|
||||
|
||||
Building on GCP is not possible because `Google Cloud Platform does not allow `gcc` to be executed.
|
||||
|
||||
Please work only with compiled final binaries.
|
||||
|
||||
## Linux
|
||||
|
||||
To compile this package on Linux you must install the development tools for your linux distribution.
|
||||
|
||||
To compile under linux use the build tag `linux`.
|
||||
|
||||
```bash
|
||||
go build --tags "linux"
|
||||
```
|
||||
|
||||
If you wish to link directly to libsqlite3 then you can use the `libsqlite3` build tag.
|
||||
|
||||
```
|
||||
go build --tags "libsqlite3 linux"
|
||||
```
|
||||
|
||||
### Alpine
|
||||
|
||||
When building in an `alpine` container run the following command before building.
|
||||
|
||||
```
|
||||
apk add --update gcc musl-dev
|
||||
```
|
||||
|
||||
### Fedora
|
||||
|
||||
```bash
|
||||
sudo yum groupinstall "Development Tools" "Development Libraries"
|
||||
```
|
||||
|
||||
### Ubuntu
|
||||
|
||||
```bash
|
||||
sudo apt-get install build-essential
|
||||
```
|
||||
|
||||
## Mac OSX
|
||||
|
||||
OSX should have all the tools present to compile this package, if not install XCode this will add all the developers tools.
|
||||
|
||||
Required dependency
|
||||
|
||||
```bash
|
||||
brew install sqlite3
|
||||
```
|
||||
|
||||
For OSX there is an additional package install which is required if you whish to build the `icu` extension.
|
||||
|
||||
This additional package can be installed with `homebrew`.
|
||||
|
||||
```bash
|
||||
brew upgrade icu4c
|
||||
```
|
||||
|
||||
To compile for Mac OSX.
|
||||
|
||||
```bash
|
||||
go build --tags "darwin"
|
||||
```
|
||||
|
||||
If you wish to link directly to libsqlite3 then you can use the `libsqlite3` build tag.
|
||||
|
||||
```
|
||||
go build --tags "libsqlite3 darwin"
|
||||
```
|
||||
|
||||
Additional information:
|
||||
- [#206](https://github.com/mattn/go-sqlite3/issues/206)
|
||||
- [#404](https://github.com/mattn/go-sqlite3/issues/404)
|
||||
|
||||
## Windows
|
||||
|
||||
To compile this package on Windows OS you must have the `gcc` compiler installed.
|
||||
|
||||
1) Install a Windows `gcc` toolchain.
|
||||
2) Add the `bin` folders to the Windows path if the installer did not do this by default.
|
||||
3) Open a terminal for the TDM-GCC toolchain, can be found in the Windows Start menu.
|
||||
4) Navigate to your project folder and run the `go build ...` command for this package.
|
||||
|
||||
For example the TDM-GCC Toolchain can be found [here](ttps://sourceforge.net/projects/tdm-gcc/).
|
||||
|
||||
## Errors
|
||||
|
||||
- Compile error: `can not be used when making a shared object; recompile with -fPIC`
|
||||
|
||||
When receiving a compile time error referencing recompile with `-FPIC` then you
|
||||
are probably using a hardend system.
|
||||
|
||||
You can copile the library on a hardend system with the following command.
|
||||
|
||||
```bash
|
||||
go build -ldflags '-extldflags=-fno-PIC'
|
||||
```
|
||||
|
||||
More details see [#120](https://github.com/mattn/go-sqlite3/issues/120)
|
||||
|
||||
- Can't build go-sqlite3 on windows 64bit.
|
||||
|
||||
> Probably, you are using go 1.0, go1.0 has a problem when it comes to compiling/linking on windows 64bit.
|
||||
> See: [#27](https://github.com/mattn/go-sqlite3/issues/27)
|
||||
|
||||
- `go get github.com/mattn/go-sqlite3` throws compilation error.
|
||||
|
||||
`gcc` throws: `internal compiler error`
|
||||
|
||||
Remove the download repository from your disk and try re-install with:
|
||||
|
||||
```bash
|
||||
go install github.com/mattn/go-sqlite3
|
||||
```
|
||||
|
||||
# User Authentication
|
||||
|
||||
This package supports the SQLite User Authentication module.
|
||||
|
||||
## Compile
|
||||
|
||||
To use the User authentication module the package has to be compiled with the tag `sqlite_userauth`. See [Features](#features).
|
||||
|
||||
## Usage
|
||||
|
||||
### Create protected database
|
||||
|
||||
To create a database protected by user authentication provide the following argument to the connection string `_auth`.
|
||||
This will enable user authentication within the database. This option however requires two additional arguments:
|
||||
|
||||
- `_auth_user`
|
||||
- `_auth_pass`
|
||||
|
||||
When `_auth` is present on the connection string user authentication will be enabled and the provided user will be created
|
||||
as an `admin` user. After initial creation, the parameter `_auth` has no effect anymore and can be omitted from the connection string.
|
||||
|
||||
Example connection string:
|
||||
|
||||
Create an user authentication database with user `admin` and password `admin`.
|
||||
|
||||
`file:test.s3db?_auth&_auth_user=admin&_auth_pass=admin`
|
||||
|
||||
Create an user authentication database with user `admin` and password `admin` and use `SHA1` for the password encoding.
|
||||
|
||||
`file:test.s3db?_auth&_auth_user=admin&_auth_pass=admin&_auth_crypt=sha1`
|
||||
|
||||
### Password Encoding
|
||||
|
||||
The passwords within the user authentication module of SQLite are encoded with the SQLite function `sqlite_cryp`.
|
||||
This function uses a ceasar-cypher which is quite insecure.
|
||||
This library provides several additional password encoders which can be configured through the connection string.
|
||||
|
||||
The password cypher can be configured with the key `_auth_crypt`. And if the configured password encoder also requires an
|
||||
salt this can be configured with `_auth_salt`.
|
||||
|
||||
#### Available Encoders
|
||||
|
||||
- SHA1
|
||||
- SSHA1 (Salted SHA1)
|
||||
- SHA256
|
||||
- SSHA256 (salted SHA256)
|
||||
- SHA384
|
||||
- SSHA384 (salted SHA384)
|
||||
- SHA512
|
||||
- SSHA512 (salted SHA512)
|
||||
|
||||
### Restrictions
|
||||
|
||||
Operations on the database regarding to user management can only be preformed by an administrator user.
|
||||
|
||||
### Support
|
||||
|
||||
The user authentication supports two kinds of users
|
||||
|
||||
- administrators
|
||||
- regular users
|
||||
|
||||
### User Management
|
||||
|
||||
User management can be done by directly using the `*SQLiteConn` or by SQL.
|
||||
|
||||
#### SQL
|
||||
|
||||
The following sql functions are available for user management.
|
||||
|
||||
| Function | Arguments | Description |
|
||||
|----------|-----------|-------------|
|
||||
| `authenticate` | username `string`, password `string` | Will authenticate an user, this is done by the connection; and should not be used manually. |
|
||||
| `auth_user_add` | username `string`, password `string`, admin `int` | This function will add an user to the database.<br>if the database is not protected by user authentication it will enable it. Argument `admin` is an integer identifying if the added user should be an administrator. Only Administrators can add administrators. |
|
||||
| `auth_user_change` | username `string`, password `string`, admin `int` | Function to modify an user. Users can change their own password, but only an administrator can change the administrator flag. |
|
||||
| `authUserDelete` | username `string` | Delete an user from the database. Can only be used by an administrator. The current logged in administrator cannot be deleted. This is to make sure their is always an administrator remaining. |
|
||||
|
||||
These functions will return an integer.
|
||||
|
||||
- 0 (SQLITE_OK)
|
||||
- 23 (SQLITE_AUTH) Failed to perform due to authentication or insufficient privileges
|
||||
|
||||
##### Examples
|
||||
|
||||
```sql
|
||||
// Autheticate user
|
||||
// Create Admin User
|
||||
SELECT auth_user_add('admin2', 'admin2', 1);
|
||||
|
||||
// Change password for user
|
||||
SELECT auth_user_change('user', 'userpassword', 0);
|
||||
|
||||
// Delete user
|
||||
SELECT user_delete('user');
|
||||
```
|
||||
|
||||
#### *SQLiteConn
|
||||
|
||||
The following functions are available for User authentication from the `*SQLiteConn`.
|
||||
|
||||
| Function | Description |
|
||||
|----------|-------------|
|
||||
| `Authenticate(username, password string) error` | Authenticate user |
|
||||
| `AuthUserAdd(username, password string, admin bool) error` | Add user |
|
||||
| `AuthUserChange(username, password string, admin bool) error` | Modify user |
|
||||
| `AuthUserDelete(username string) error` | Delete user |
|
||||
|
||||
### Attached database
|
||||
|
||||
When using attached databases. SQLite will use the authentication from the `main` database for the attached database(s).
|
||||
|
||||
# Extensions
|
||||
|
||||
If you want your own extension to be listed here or you want to add a reference to an extension; please submit an Issue for this.
|
||||
|
||||
## Spatialite
|
||||
|
||||
Spatialite is available as an extension to SQLite, and can be used in combination with this repository.
|
||||
For an example see [shaxbee/go-spatialite](https://github.com/shaxbee/go-spatialite).
|
||||
|
||||
# FAQ
|
||||
|
||||
- Getting insert error while query is opened.
|
||||
|
||||
> You can pass some arguments into the connection string, for example, a URI.
|
||||
> See: [#39](https://github.com/mattn/go-sqlite3/issues/39)
|
||||
|
||||
- Do you want to cross compile? mingw on Linux or Mac?
|
||||
|
||||
> See: [#106](https://github.com/mattn/go-sqlite3/issues/106)
|
||||
> See also: http://www.limitlessfx.com/cross-compile-golang-app-for-windows-from-linux.html
|
||||
|
||||
- Want to get time.Time with current locale
|
||||
|
||||
Use `_loc=auto` in SQLite3 filename schema like `file:foo.db?_loc=auto`.
|
||||
|
||||
- Can I use this in multiple routines concurrently?
|
||||
|
||||
Yes for readonly. But, No for writable. See [#50](https://github.com/mattn/go-sqlite3/issues/50), [#51](https://github.com/mattn/go-sqlite3/issues/51), [#209](https://github.com/mattn/go-sqlite3/issues/209), [#274](https://github.com/mattn/go-sqlite3/issues/274).
|
||||
|
||||
- Why I'm getting `no such table` error?
|
||||
|
||||
Why is it racy if I use a `sql.Open("sqlite3", ":memory:")` database?
|
||||
|
||||
Each connection to :memory: opens a brand new in-memory sql database, so if
|
||||
the stdlib's sql engine happens to open another connection and you've only
|
||||
specified ":memory:", that connection will see a brand new database. A
|
||||
workaround is to use "file::memory:?mode=memory&cache=shared". Every
|
||||
connection to this string will point to the same in-memory database.
|
||||
|
||||
For more information see
|
||||
* [#204](https://github.com/mattn/go-sqlite3/issues/204)
|
||||
* [#511](https://github.com/mattn/go-sqlite3/issues/511)
|
||||
|
||||
- Reading from database with large amount of goroutines fails on OSX.
|
||||
|
||||
OS X limits OS-wide to not have more than 1000 files open simultaneously by default.
|
||||
|
||||
For more information see [#289](https://github.com/mattn/go-sqlite3/issues/289)
|
||||
|
||||
- Trying to execure a `.` (dot) command throws an error.
|
||||
|
||||
Error: `Error: near ".": syntax error`
|
||||
Dot command are part of SQLite3 CLI not of this library.
|
||||
|
||||
You need to implement the feature or call the sqlite3 cli.
|
||||
|
||||
More infomation see [#305](https://github.com/mattn/go-sqlite3/issues/305)
|
||||
|
||||
- Error: `database is locked`
|
||||
|
||||
When you get an database is locked. Please use the following options.
|
||||
|
||||
Add to DSN: `cache=shared`
|
||||
|
||||
Example:
|
||||
```go
|
||||
db, err := sql.Open("sqlite3", "file:locked.sqlite?cache=shared")
|
||||
```
|
||||
|
||||
Second please set the database connections of the SQL package to 1.
|
||||
|
||||
```go
|
||||
db.SetMaxOpenConn(1)
|
||||
```
|
||||
|
||||
More information see [#209](https://github.com/mattn/go-sqlite3/issues/209)
|
||||
|
||||
# License
|
||||
|
||||
MIT: http://mattn.mit-license.org/2018
|
||||
|
||||
sqlite3-binding.c, sqlite3-binding.h, sqlite3ext.h
|
||||
|
||||
The -binding suffix was added to avoid build failures under gccgo.
|
||||
|
||||
In this repository, those files are an amalgamation of code that was copied from SQLite3. The license of that code is the same as the license of SQLite3.
|
||||
|
||||
# Author
|
||||
|
||||
Yasuhiro Matsumoto (a.k.a mattn)
|
||||
|
||||
G.J.R. Timmer
|
85
vendor/github.com/mattn/go-sqlite3/backup.go
generated
vendored
85
vendor/github.com/mattn/go-sqlite3/backup.go
generated
vendored
@ -1,85 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// SQLiteBackup implement interface of Backup.
|
||||
type SQLiteBackup struct {
|
||||
b *C.sqlite3_backup
|
||||
}
|
||||
|
||||
// Backup make backup from src to dest.
|
||||
func (c *SQLiteConn) Backup(dest string, conn *SQLiteConn, src string) (*SQLiteBackup, error) {
|
||||
destptr := C.CString(dest)
|
||||
defer C.free(unsafe.Pointer(destptr))
|
||||
srcptr := C.CString(src)
|
||||
defer C.free(unsafe.Pointer(srcptr))
|
||||
|
||||
if b := C.sqlite3_backup_init(c.db, destptr, conn.db, srcptr); b != nil {
|
||||
bb := &SQLiteBackup{b: b}
|
||||
runtime.SetFinalizer(bb, (*SQLiteBackup).Finish)
|
||||
return bb, nil
|
||||
}
|
||||
return nil, c.lastError()
|
||||
}
|
||||
|
||||
// Step to backs up for one step. Calls the underlying `sqlite3_backup_step`
|
||||
// function. This function returns a boolean indicating if the backup is done
|
||||
// and an error signalling any other error. Done is returned if the underlying
|
||||
// C function returns SQLITE_DONE (Code 101)
|
||||
func (b *SQLiteBackup) Step(p int) (bool, error) {
|
||||
ret := C.sqlite3_backup_step(b.b, C.int(p))
|
||||
if ret == C.SQLITE_DONE {
|
||||
return true, nil
|
||||
} else if ret != 0 && ret != C.SQLITE_LOCKED && ret != C.SQLITE_BUSY {
|
||||
return false, Error{Code: ErrNo(ret)}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Remaining return whether have the rest for backup.
|
||||
func (b *SQLiteBackup) Remaining() int {
|
||||
return int(C.sqlite3_backup_remaining(b.b))
|
||||
}
|
||||
|
||||
// PageCount return count of pages.
|
||||
func (b *SQLiteBackup) PageCount() int {
|
||||
return int(C.sqlite3_backup_pagecount(b.b))
|
||||
}
|
||||
|
||||
// Finish close backup.
|
||||
func (b *SQLiteBackup) Finish() error {
|
||||
return b.Close()
|
||||
}
|
||||
|
||||
// Close close backup.
|
||||
func (b *SQLiteBackup) Close() error {
|
||||
ret := C.sqlite3_backup_finish(b.b)
|
||||
|
||||
// sqlite3_backup_finish() never fails, it just returns the
|
||||
// error code from previous operations, so clean up before
|
||||
// checking and returning an error
|
||||
b.b = nil
|
||||
runtime.SetFinalizer(b, nil)
|
||||
|
||||
if ret != 0 {
|
||||
return Error{Code: ErrNo(ret)}
|
||||
}
|
||||
return nil
|
||||
}
|
374
vendor/github.com/mattn/go-sqlite3/callback.go
generated
vendored
374
vendor/github.com/mattn/go-sqlite3/callback.go
generated
vendored
@ -1,374 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqlite3
|
||||
|
||||
// You can't export a Go function to C and have definitions in the C
|
||||
// preamble in the same file, so we have to have callbackTrampoline in
|
||||
// its own file. Because we need a separate file anyway, the support
|
||||
// code for SQLite custom functions is in here.
|
||||
|
||||
/*
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
|
||||
void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
|
||||
void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
//export callbackTrampoline
|
||||
func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
|
||||
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
|
||||
fi := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*functionInfo)
|
||||
fi.Call(ctx, args)
|
||||
}
|
||||
|
||||
//export stepTrampoline
|
||||
func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
|
||||
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
|
||||
ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
|
||||
ai.Step(ctx, args)
|
||||
}
|
||||
|
||||
//export doneTrampoline
|
||||
func doneTrampoline(ctx *C.sqlite3_context) {
|
||||
handle := uintptr(C.sqlite3_user_data(ctx))
|
||||
ai := lookupHandle(handle).(*aggInfo)
|
||||
ai.Done(ctx)
|
||||
}
|
||||
|
||||
//export compareTrampoline
|
||||
func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.char) C.int {
|
||||
cmp := lookupHandle(handlePtr).(func(string, string) int)
|
||||
return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
|
||||
}
|
||||
|
||||
//export commitHookTrampoline
|
||||
func commitHookTrampoline(handle uintptr) int {
|
||||
callback := lookupHandle(handle).(func() int)
|
||||
return callback()
|
||||
}
|
||||
|
||||
//export rollbackHookTrampoline
|
||||
func rollbackHookTrampoline(handle uintptr) {
|
||||
callback := lookupHandle(handle).(func())
|
||||
callback()
|
||||
}
|
||||
|
||||
//export updateHookTrampoline
|
||||
func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
|
||||
callback := lookupHandle(handle).(func(int, string, string, int64))
|
||||
callback(op, C.GoString(db), C.GoString(table), rowid)
|
||||
}
|
||||
|
||||
// Use handles to avoid passing Go pointers to C.
|
||||
|
||||
type handleVal struct {
|
||||
db *SQLiteConn
|
||||
val interface{}
|
||||
}
|
||||
|
||||
var handleLock sync.Mutex
|
||||
var handleVals = make(map[uintptr]handleVal)
|
||||
var handleIndex uintptr = 100
|
||||
|
||||
func newHandle(db *SQLiteConn, v interface{}) uintptr {
|
||||
handleLock.Lock()
|
||||
defer handleLock.Unlock()
|
||||
i := handleIndex
|
||||
handleIndex++
|
||||
handleVals[i] = handleVal{db, v}
|
||||
return i
|
||||
}
|
||||
|
||||
func lookupHandle(handle uintptr) interface{} {
|
||||
handleLock.Lock()
|
||||
defer handleLock.Unlock()
|
||||
r, ok := handleVals[handle]
|
||||
if !ok {
|
||||
if handle >= 100 && handle < handleIndex {
|
||||
panic("deleted handle")
|
||||
} else {
|
||||
panic("invalid handle")
|
||||
}
|
||||
}
|
||||
return r.val
|
||||
}
|
||||
|
||||
func deleteHandles(db *SQLiteConn) {
|
||||
handleLock.Lock()
|
||||
defer handleLock.Unlock()
|
||||
for handle, val := range handleVals {
|
||||
if val.db == db {
|
||||
delete(handleVals, handle)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This is only here so that tests can refer to it.
|
||||
type callbackArgRaw C.sqlite3_value
|
||||
|
||||
type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
|
||||
|
||||
type callbackArgCast struct {
|
||||
f callbackArgConverter
|
||||
typ reflect.Type
|
||||
}
|
||||
|
||||
func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
val, err := c.f(v)
|
||||
if err != nil {
|
||||
return reflect.Value{}, err
|
||||
}
|
||||
if !val.Type().ConvertibleTo(c.typ) {
|
||||
return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
|
||||
}
|
||||
return val.Convert(c.typ), nil
|
||||
}
|
||||
|
||||
func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
|
||||
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
|
||||
}
|
||||
return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
|
||||
}
|
||||
|
||||
func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
|
||||
return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
|
||||
}
|
||||
i := int64(C.sqlite3_value_int64(v))
|
||||
val := false
|
||||
if i != 0 {
|
||||
val = true
|
||||
}
|
||||
return reflect.ValueOf(val), nil
|
||||
}
|
||||
|
||||
func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
|
||||
return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
|
||||
}
|
||||
return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
|
||||
}
|
||||
|
||||
func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
switch C.sqlite3_value_type(v) {
|
||||
case C.SQLITE_BLOB:
|
||||
l := C.sqlite3_value_bytes(v)
|
||||
p := C.sqlite3_value_blob(v)
|
||||
return reflect.ValueOf(C.GoBytes(p, l)), nil
|
||||
case C.SQLITE_TEXT:
|
||||
l := C.sqlite3_value_bytes(v)
|
||||
c := unsafe.Pointer(C.sqlite3_value_text(v))
|
||||
return reflect.ValueOf(C.GoBytes(c, l)), nil
|
||||
default:
|
||||
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
|
||||
}
|
||||
}
|
||||
|
||||
func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
switch C.sqlite3_value_type(v) {
|
||||
case C.SQLITE_BLOB:
|
||||
l := C.sqlite3_value_bytes(v)
|
||||
p := (*C.char)(C.sqlite3_value_blob(v))
|
||||
return reflect.ValueOf(C.GoStringN(p, l)), nil
|
||||
case C.SQLITE_TEXT:
|
||||
c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
|
||||
return reflect.ValueOf(C.GoString(c)), nil
|
||||
default:
|
||||
return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
|
||||
}
|
||||
}
|
||||
|
||||
func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
|
||||
switch C.sqlite3_value_type(v) {
|
||||
case C.SQLITE_INTEGER:
|
||||
return callbackArgInt64(v)
|
||||
case C.SQLITE_FLOAT:
|
||||
return callbackArgFloat64(v)
|
||||
case C.SQLITE_TEXT:
|
||||
return callbackArgString(v)
|
||||
case C.SQLITE_BLOB:
|
||||
return callbackArgBytes(v)
|
||||
case C.SQLITE_NULL:
|
||||
// Interpret NULL as a nil byte slice.
|
||||
var ret []byte
|
||||
return reflect.ValueOf(ret), nil
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
|
||||
switch typ.Kind() {
|
||||
case reflect.Interface:
|
||||
if typ.NumMethod() != 0 {
|
||||
return nil, errors.New("the only supported interface type is interface{}")
|
||||
}
|
||||
return callbackArgGeneric, nil
|
||||
case reflect.Slice:
|
||||
if typ.Elem().Kind() != reflect.Uint8 {
|
||||
return nil, errors.New("the only supported slice type is []byte")
|
||||
}
|
||||
return callbackArgBytes, nil
|
||||
case reflect.String:
|
||||
return callbackArgString, nil
|
||||
case reflect.Bool:
|
||||
return callbackArgBool, nil
|
||||
case reflect.Int64:
|
||||
return callbackArgInt64, nil
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
|
||||
c := callbackArgCast{callbackArgInt64, typ}
|
||||
return c.Run, nil
|
||||
case reflect.Float64:
|
||||
return callbackArgFloat64, nil
|
||||
case reflect.Float32:
|
||||
c := callbackArgCast{callbackArgFloat64, typ}
|
||||
return c.Run, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("don't know how to convert to %s", typ)
|
||||
}
|
||||
}
|
||||
|
||||
func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) {
|
||||
var args []reflect.Value
|
||||
|
||||
if len(argv) < len(converters) {
|
||||
return nil, fmt.Errorf("function requires at least %d arguments", len(converters))
|
||||
}
|
||||
|
||||
for i, arg := range argv[:len(converters)] {
|
||||
v, err := converters[i](arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, v)
|
||||
}
|
||||
|
||||
if variadic != nil {
|
||||
for _, arg := range argv[len(converters):] {
|
||||
v, err := variadic(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, v)
|
||||
}
|
||||
}
|
||||
return args, nil
|
||||
}
|
||||
|
||||
type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
|
||||
|
||||
func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Int64:
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
|
||||
v = v.Convert(reflect.TypeOf(int64(0)))
|
||||
case reflect.Bool:
|
||||
b := v.Interface().(bool)
|
||||
if b {
|
||||
v = reflect.ValueOf(int64(1))
|
||||
} else {
|
||||
v = reflect.ValueOf(int64(0))
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
|
||||
}
|
||||
|
||||
C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||
switch v.Type().Kind() {
|
||||
case reflect.Float64:
|
||||
case reflect.Float32:
|
||||
v = v.Convert(reflect.TypeOf(float64(0)))
|
||||
default:
|
||||
return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
|
||||
}
|
||||
|
||||
C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||
if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return fmt.Errorf("cannot convert %s to BLOB", v.Type())
|
||||
}
|
||||
i := v.Interface()
|
||||
if i == nil || len(i.([]byte)) == 0 {
|
||||
C.sqlite3_result_null(ctx)
|
||||
} else {
|
||||
bs := i.([]byte)
|
||||
C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||
if v.Type().Kind() != reflect.String {
|
||||
return fmt.Errorf("cannot convert %s to TEXT", v.Type())
|
||||
}
|
||||
C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
|
||||
switch typ.Kind() {
|
||||
case reflect.Interface:
|
||||
errorInterface := reflect.TypeOf((*error)(nil)).Elem()
|
||||
if typ.Implements(errorInterface) {
|
||||
return callbackRetNil, nil
|
||||
}
|
||||
fallthrough
|
||||
case reflect.Slice:
|
||||
if typ.Elem().Kind() != reflect.Uint8 {
|
||||
return nil, errors.New("the only supported slice type is []byte")
|
||||
}
|
||||
return callbackRetBlob, nil
|
||||
case reflect.String:
|
||||
return callbackRetText, nil
|
||||
case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
|
||||
return callbackRetInteger, nil
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return callbackRetFloat, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("don't know how to convert to %s", typ)
|
||||
}
|
||||
}
|
||||
|
||||
func callbackError(ctx *C.sqlite3_context, err error) {
|
||||
cstr := C.CString(err.Error())
|
||||
defer C.free(unsafe.Pointer(cstr))
|
||||
C.sqlite3_result_error(ctx, cstr, -1)
|
||||
}
|
||||
|
||||
// Test support code. Tests are not allowed to import "C", so we can't
|
||||
// declare any functions that use C.sqlite3_value.
|
||||
func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
|
||||
return func(*C.sqlite3_value) (reflect.Value, error) {
|
||||
return v, err
|
||||
}
|
||||
}
|
112
vendor/github.com/mattn/go-sqlite3/doc.go
generated
vendored
112
vendor/github.com/mattn/go-sqlite3/doc.go
generated
vendored
@ -1,112 +0,0 @@
|
||||
/*
|
||||
Package sqlite3 provides interface to SQLite3 databases.
|
||||
|
||||
This works as a driver for database/sql.
|
||||
|
||||
Installation
|
||||
|
||||
go get github.com/mattn/go-sqlite3
|
||||
|
||||
Supported Types
|
||||
|
||||
Currently, go-sqlite3 supports the following data types.
|
||||
|
||||
+------------------------------+
|
||||
|go | sqlite3 |
|
||||
|----------|-------------------|
|
||||
|nil | null |
|
||||
|int | integer |
|
||||
|int64 | integer |
|
||||
|float64 | float |
|
||||
|bool | integer |
|
||||
|[]byte | blob |
|
||||
|string | text |
|
||||
|time.Time | timestamp/datetime|
|
||||
+------------------------------+
|
||||
|
||||
SQLite3 Extension
|
||||
|
||||
You can write your own extension module for sqlite3. For example, below is an
|
||||
extension for a Regexp matcher operation.
|
||||
|
||||
#include <pcre.h>
|
||||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
#include <sqlite3ext.h>
|
||||
|
||||
SQLITE_EXTENSION_INIT1
|
||||
static void regexp_func(sqlite3_context *context, int argc, sqlite3_value **argv) {
|
||||
if (argc >= 2) {
|
||||
const char *target = (const char *)sqlite3_value_text(argv[1]);
|
||||
const char *pattern = (const char *)sqlite3_value_text(argv[0]);
|
||||
const char* errstr = NULL;
|
||||
int erroff = 0;
|
||||
int vec[500];
|
||||
int n, rc;
|
||||
pcre* re = pcre_compile(pattern, 0, &errstr, &erroff, NULL);
|
||||
rc = pcre_exec(re, NULL, target, strlen(target), 0, 0, vec, 500);
|
||||
if (rc <= 0) {
|
||||
sqlite3_result_error(context, errstr, 0);
|
||||
return;
|
||||
}
|
||||
sqlite3_result_int(context, 1);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef _WIN32
|
||||
__declspec(dllexport)
|
||||
#endif
|
||||
int sqlite3_extension_init(sqlite3 *db, char **errmsg,
|
||||
const sqlite3_api_routines *api) {
|
||||
SQLITE_EXTENSION_INIT2(api);
|
||||
return sqlite3_create_function(db, "regexp", 2, SQLITE_UTF8,
|
||||
(void*)db, regexp_func, NULL, NULL);
|
||||
}
|
||||
|
||||
It needs to be built as a so/dll shared library. And you need to register
|
||||
the extension module like below.
|
||||
|
||||
sql.Register("sqlite3_with_extensions",
|
||||
&sqlite3.SQLiteDriver{
|
||||
Extensions: []string{
|
||||
"sqlite3_mod_regexp",
|
||||
},
|
||||
})
|
||||
|
||||
Then, you can use this extension.
|
||||
|
||||
rows, err := db.Query("select text from mytable where name regexp '^golang'")
|
||||
|
||||
Connection Hook
|
||||
|
||||
You can hook and inject your code when the connection is established. database/sql
|
||||
doesn't provide a way to get native go-sqlite3 interfaces. So if you want,
|
||||
you need to set ConnectHook and get the SQLiteConn.
|
||||
|
||||
sql.Register("sqlite3_with_hook_example",
|
||||
&sqlite3.SQLiteDriver{
|
||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||
sqlite3conn = append(sqlite3conn, conn)
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
Go SQlite3 Extensions
|
||||
|
||||
If you want to register Go functions as SQLite extension functions,
|
||||
call RegisterFunction from ConnectHook.
|
||||
|
||||
regex = func(re, s string) (bool, error) {
|
||||
return regexp.MatchString(re, s)
|
||||
}
|
||||
sql.Register("sqlite3_with_go_func",
|
||||
&sqlite3.SQLiteDriver{
|
||||
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
|
||||
return conn.RegisterFunc("regexp", regex, true)
|
||||
},
|
||||
})
|
||||
|
||||
See the documentation of RegisterFunc for more details.
|
||||
|
||||
*/
|
||||
package sqlite3
|
135
vendor/github.com/mattn/go-sqlite3/error.go
generated
vendored
135
vendor/github.com/mattn/go-sqlite3/error.go
generated
vendored
@ -1,135 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import "C"
|
||||
|
||||
// ErrNo inherit errno.
|
||||
type ErrNo int
|
||||
|
||||
// ErrNoMask is mask code.
|
||||
const ErrNoMask C.int = 0xff
|
||||
|
||||
// ErrNoExtended is extended errno.
|
||||
type ErrNoExtended int
|
||||
|
||||
// Error implement sqlite error code.
|
||||
type Error struct {
|
||||
Code ErrNo /* The error code returned by SQLite */
|
||||
ExtendedCode ErrNoExtended /* The extended error code returned by SQLite */
|
||||
err string /* The error string returned by sqlite3_errmsg(),
|
||||
this usually contains more specific details. */
|
||||
}
|
||||
|
||||
// result codes from http://www.sqlite.org/c3ref/c_abort.html
|
||||
var (
|
||||
ErrError = ErrNo(1) /* SQL error or missing database */
|
||||
ErrInternal = ErrNo(2) /* Internal logic error in SQLite */
|
||||
ErrPerm = ErrNo(3) /* Access permission denied */
|
||||
ErrAbort = ErrNo(4) /* Callback routine requested an abort */
|
||||
ErrBusy = ErrNo(5) /* The database file is locked */
|
||||
ErrLocked = ErrNo(6) /* A table in the database is locked */
|
||||
ErrNomem = ErrNo(7) /* A malloc() failed */
|
||||
ErrReadonly = ErrNo(8) /* Attempt to write a readonly database */
|
||||
ErrInterrupt = ErrNo(9) /* Operation terminated by sqlite3_interrupt() */
|
||||
ErrIoErr = ErrNo(10) /* Some kind of disk I/O error occurred */
|
||||
ErrCorrupt = ErrNo(11) /* The database disk image is malformed */
|
||||
ErrNotFound = ErrNo(12) /* Unknown opcode in sqlite3_file_control() */
|
||||
ErrFull = ErrNo(13) /* Insertion failed because database is full */
|
||||
ErrCantOpen = ErrNo(14) /* Unable to open the database file */
|
||||
ErrProtocol = ErrNo(15) /* Database lock protocol error */
|
||||
ErrEmpty = ErrNo(16) /* Database is empty */
|
||||
ErrSchema = ErrNo(17) /* The database schema changed */
|
||||
ErrTooBig = ErrNo(18) /* String or BLOB exceeds size limit */
|
||||
ErrConstraint = ErrNo(19) /* Abort due to constraint violation */
|
||||
ErrMismatch = ErrNo(20) /* Data type mismatch */
|
||||
ErrMisuse = ErrNo(21) /* Library used incorrectly */
|
||||
ErrNoLFS = ErrNo(22) /* Uses OS features not supported on host */
|
||||
ErrAuth = ErrNo(23) /* Authorization denied */
|
||||
ErrFormat = ErrNo(24) /* Auxiliary database format error */
|
||||
ErrRange = ErrNo(25) /* 2nd parameter to sqlite3_bind out of range */
|
||||
ErrNotADB = ErrNo(26) /* File opened that is not a database file */
|
||||
ErrNotice = ErrNo(27) /* Notifications from sqlite3_log() */
|
||||
ErrWarning = ErrNo(28) /* Warnings from sqlite3_log() */
|
||||
)
|
||||
|
||||
// Error return error message from errno.
|
||||
func (err ErrNo) Error() string {
|
||||
return Error{Code: err}.Error()
|
||||
}
|
||||
|
||||
// Extend return extended errno.
|
||||
func (err ErrNo) Extend(by int) ErrNoExtended {
|
||||
return ErrNoExtended(int(err) | (by << 8))
|
||||
}
|
||||
|
||||
// Error return error message that is extended code.
|
||||
func (err ErrNoExtended) Error() string {
|
||||
return Error{Code: ErrNo(C.int(err) & ErrNoMask), ExtendedCode: err}.Error()
|
||||
}
|
||||
|
||||
func (err Error) Error() string {
|
||||
if err.err != "" {
|
||||
return err.err
|
||||
}
|
||||
return errorString(err)
|
||||
}
|
||||
|
||||
// result codes from http://www.sqlite.org/c3ref/c_abort_rollback.html
|
||||
var (
|
||||
ErrIoErrRead = ErrIoErr.Extend(1)
|
||||
ErrIoErrShortRead = ErrIoErr.Extend(2)
|
||||
ErrIoErrWrite = ErrIoErr.Extend(3)
|
||||
ErrIoErrFsync = ErrIoErr.Extend(4)
|
||||
ErrIoErrDirFsync = ErrIoErr.Extend(5)
|
||||
ErrIoErrTruncate = ErrIoErr.Extend(6)
|
||||
ErrIoErrFstat = ErrIoErr.Extend(7)
|
||||
ErrIoErrUnlock = ErrIoErr.Extend(8)
|
||||
ErrIoErrRDlock = ErrIoErr.Extend(9)
|
||||
ErrIoErrDelete = ErrIoErr.Extend(10)
|
||||
ErrIoErrBlocked = ErrIoErr.Extend(11)
|
||||
ErrIoErrNoMem = ErrIoErr.Extend(12)
|
||||
ErrIoErrAccess = ErrIoErr.Extend(13)
|
||||
ErrIoErrCheckReservedLock = ErrIoErr.Extend(14)
|
||||
ErrIoErrLock = ErrIoErr.Extend(15)
|
||||
ErrIoErrClose = ErrIoErr.Extend(16)
|
||||
ErrIoErrDirClose = ErrIoErr.Extend(17)
|
||||
ErrIoErrSHMOpen = ErrIoErr.Extend(18)
|
||||
ErrIoErrSHMSize = ErrIoErr.Extend(19)
|
||||
ErrIoErrSHMLock = ErrIoErr.Extend(20)
|
||||
ErrIoErrSHMMap = ErrIoErr.Extend(21)
|
||||
ErrIoErrSeek = ErrIoErr.Extend(22)
|
||||
ErrIoErrDeleteNoent = ErrIoErr.Extend(23)
|
||||
ErrIoErrMMap = ErrIoErr.Extend(24)
|
||||
ErrIoErrGetTempPath = ErrIoErr.Extend(25)
|
||||
ErrIoErrConvPath = ErrIoErr.Extend(26)
|
||||
ErrLockedSharedCache = ErrLocked.Extend(1)
|
||||
ErrBusyRecovery = ErrBusy.Extend(1)
|
||||
ErrBusySnapshot = ErrBusy.Extend(2)
|
||||
ErrCantOpenNoTempDir = ErrCantOpen.Extend(1)
|
||||
ErrCantOpenIsDir = ErrCantOpen.Extend(2)
|
||||
ErrCantOpenFullPath = ErrCantOpen.Extend(3)
|
||||
ErrCantOpenConvPath = ErrCantOpen.Extend(4)
|
||||
ErrCorruptVTab = ErrCorrupt.Extend(1)
|
||||
ErrReadonlyRecovery = ErrReadonly.Extend(1)
|
||||
ErrReadonlyCantLock = ErrReadonly.Extend(2)
|
||||
ErrReadonlyRollback = ErrReadonly.Extend(3)
|
||||
ErrReadonlyDbMoved = ErrReadonly.Extend(4)
|
||||
ErrAbortRollback = ErrAbort.Extend(2)
|
||||
ErrConstraintCheck = ErrConstraint.Extend(1)
|
||||
ErrConstraintCommitHook = ErrConstraint.Extend(2)
|
||||
ErrConstraintForeignKey = ErrConstraint.Extend(3)
|
||||
ErrConstraintFunction = ErrConstraint.Extend(4)
|
||||
ErrConstraintNotNull = ErrConstraint.Extend(5)
|
||||
ErrConstraintPrimaryKey = ErrConstraint.Extend(6)
|
||||
ErrConstraintTrigger = ErrConstraint.Extend(7)
|
||||
ErrConstraintUnique = ErrConstraint.Extend(8)
|
||||
ErrConstraintVTab = ErrConstraint.Extend(9)
|
||||
ErrConstraintRowID = ErrConstraint.Extend(10)
|
||||
ErrNoticeRecoverWAL = ErrNotice.Extend(1)
|
||||
ErrNoticeRecoverRollback = ErrNotice.Extend(2)
|
||||
ErrWarningAutoIndex = ErrWarning.Extend(1)
|
||||
)
|
212224
vendor/github.com/mattn/go-sqlite3/sqlite3-binding.c
generated
vendored
212224
vendor/github.com/mattn/go-sqlite3/sqlite3-binding.c
generated
vendored
File diff suppressed because it is too large
Load Diff
11535
vendor/github.com/mattn/go-sqlite3/sqlite3-binding.h
generated
vendored
11535
vendor/github.com/mattn/go-sqlite3/sqlite3-binding.h
generated
vendored
File diff suppressed because it is too large
Load Diff
1979
vendor/github.com/mattn/go-sqlite3/sqlite3.go
generated
vendored
1979
vendor/github.com/mattn/go-sqlite3/sqlite3.go
generated
vendored
File diff suppressed because it is too large
Load Diff
103
vendor/github.com/mattn/go-sqlite3/sqlite3_context.go
generated
vendored
103
vendor/github.com/mattn/go-sqlite3/sqlite3_context.go
generated
vendored
@ -1,103 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
// These wrappers are necessary because SQLITE_TRANSIENT
|
||||
// is a pointer constant, and cgo doesn't translate them correctly.
|
||||
|
||||
static inline void my_result_text(sqlite3_context *ctx, char *p, int np) {
|
||||
sqlite3_result_text(ctx, p, np, SQLITE_TRANSIENT);
|
||||
}
|
||||
|
||||
static inline void my_result_blob(sqlite3_context *ctx, void *p, int np) {
|
||||
sqlite3_result_blob(ctx, p, np, SQLITE_TRANSIENT);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"math"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const i64 = unsafe.Sizeof(int(0)) > 4
|
||||
|
||||
// SQLiteContext behave sqlite3_context
|
||||
type SQLiteContext C.sqlite3_context
|
||||
|
||||
// ResultBool sets the result of an SQL function.
|
||||
func (c *SQLiteContext) ResultBool(b bool) {
|
||||
if b {
|
||||
c.ResultInt(1)
|
||||
} else {
|
||||
c.ResultInt(0)
|
||||
}
|
||||
}
|
||||
|
||||
// ResultBlob sets the result of an SQL function.
|
||||
// See: sqlite3_result_blob, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultBlob(b []byte) {
|
||||
if i64 && len(b) > math.MaxInt32 {
|
||||
C.sqlite3_result_error_toobig((*C.sqlite3_context)(c))
|
||||
return
|
||||
}
|
||||
var p *byte
|
||||
if len(b) > 0 {
|
||||
p = &b[0]
|
||||
}
|
||||
C.my_result_blob((*C.sqlite3_context)(c), unsafe.Pointer(p), C.int(len(b)))
|
||||
}
|
||||
|
||||
// ResultDouble sets the result of an SQL function.
|
||||
// See: sqlite3_result_double, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultDouble(d float64) {
|
||||
C.sqlite3_result_double((*C.sqlite3_context)(c), C.double(d))
|
||||
}
|
||||
|
||||
// ResultInt sets the result of an SQL function.
|
||||
// See: sqlite3_result_int, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultInt(i int) {
|
||||
if i64 && (i > math.MaxInt32 || i < math.MinInt32) {
|
||||
C.sqlite3_result_int64((*C.sqlite3_context)(c), C.sqlite3_int64(i))
|
||||
} else {
|
||||
C.sqlite3_result_int((*C.sqlite3_context)(c), C.int(i))
|
||||
}
|
||||
}
|
||||
|
||||
// ResultInt64 sets the result of an SQL function.
|
||||
// See: sqlite3_result_int64, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultInt64(i int64) {
|
||||
C.sqlite3_result_int64((*C.sqlite3_context)(c), C.sqlite3_int64(i))
|
||||
}
|
||||
|
||||
// ResultNull sets the result of an SQL function.
|
||||
// See: sqlite3_result_null, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultNull() {
|
||||
C.sqlite3_result_null((*C.sqlite3_context)(c))
|
||||
}
|
||||
|
||||
// ResultText sets the result of an SQL function.
|
||||
// See: sqlite3_result_text, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultText(s string) {
|
||||
h := (*reflect.StringHeader)(unsafe.Pointer(&s))
|
||||
cs, l := (*C.char)(unsafe.Pointer(h.Data)), C.int(h.Len)
|
||||
C.my_result_text((*C.sqlite3_context)(c), cs, l)
|
||||
}
|
||||
|
||||
// ResultZeroblob sets the result of an SQL function.
|
||||
// See: sqlite3_result_zeroblob, http://sqlite.org/c3ref/result_blob.html
|
||||
func (c *SQLiteContext) ResultZeroblob(n int) {
|
||||
C.sqlite3_result_zeroblob((*C.sqlite3_context)(c), C.int(n))
|
||||
}
|
120
vendor/github.com/mattn/go-sqlite3/sqlite3_func_crypt.go
generated
vendored
120
vendor/github.com/mattn/go-sqlite3/sqlite3_func_crypt.go
generated
vendored
@ -1,120 +0,0 @@
|
||||
// Copyright (C) 2018 G.J.R. Timmer <gjr.timmer@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
)
|
||||
|
||||
// This file provides several different implementations for the
|
||||
// default embedded sqlite_crypt function.
|
||||
// This function is uses a ceasar-cypher by default
|
||||
// and is used within the UserAuthentication module to encode
|
||||
// the password.
|
||||
//
|
||||
// The provided functions can be used as an overload to the sqlite_crypt
|
||||
// function through the use of the RegisterFunc on the connection.
|
||||
//
|
||||
// Because the functions can serv a purpose to an end-user
|
||||
// without using the UserAuthentication module
|
||||
// the functions are default compiled in.
|
||||
//
|
||||
// From SQLITE3 - user-auth.txt
|
||||
// The sqlite_user.pw field is encoded by a built-in SQL function
|
||||
// "sqlite_crypt(X,Y)". The two arguments are both BLOBs. The first argument
|
||||
// is the plaintext password supplied to the sqlite3_user_authenticate()
|
||||
// interface. The second argument is the sqlite_user.pw value and is supplied
|
||||
// so that the function can extract the "salt" used by the password encoder.
|
||||
// The result of sqlite_crypt(X,Y) is another blob which is the value that
|
||||
// ends up being stored in sqlite_user.pw. To verify credentials X supplied
|
||||
// by the sqlite3_user_authenticate() routine, SQLite runs:
|
||||
//
|
||||
// sqlite_user.pw == sqlite_crypt(X, sqlite_user.pw)
|
||||
//
|
||||
// To compute an appropriate sqlite_user.pw value from a new or modified
|
||||
// password X, sqlite_crypt(X,NULL) is run. A new random salt is selected
|
||||
// when the second argument is NULL.
|
||||
//
|
||||
// The built-in version of of sqlite_crypt() uses a simple Ceasar-cypher
|
||||
// which prevents passwords from being revealed by searching the raw database
|
||||
// for ASCII text, but is otherwise trivally broken. For better password
|
||||
// security, the database should be encrypted using the SQLite Encryption
|
||||
// Extension or similar technology. Or, the application can use the
|
||||
// sqlite3_create_function() interface to provide an alternative
|
||||
// implementation of sqlite_crypt() that computes a stronger password hash,
|
||||
// perhaps using a cryptographic hash function like SHA1.
|
||||
|
||||
// CryptEncoderSHA1 encodes a password with SHA1
|
||||
func CryptEncoderSHA1(pass []byte, hash interface{}) []byte {
|
||||
h := sha1.Sum(pass)
|
||||
return h[:]
|
||||
}
|
||||
|
||||
// CryptEncoderSSHA1 encodes a password with SHA1 with the
|
||||
// configured salt.
|
||||
func CryptEncoderSSHA1(salt string) func(pass []byte, hash interface{}) []byte {
|
||||
return func(pass []byte, hash interface{}) []byte {
|
||||
s := []byte(salt)
|
||||
p := append(pass, s...)
|
||||
h := sha1.Sum(p)
|
||||
return h[:]
|
||||
}
|
||||
}
|
||||
|
||||
// CryptEncoderSHA256 encodes a password with SHA256
|
||||
func CryptEncoderSHA256(pass []byte, hash interface{}) []byte {
|
||||
h := sha256.Sum256(pass)
|
||||
return h[:]
|
||||
}
|
||||
|
||||
// CryptEncoderSSHA256 encodes a password with SHA256
|
||||
// with the configured salt
|
||||
func CryptEncoderSSHA256(salt string) func(pass []byte, hash interface{}) []byte {
|
||||
return func(pass []byte, hash interface{}) []byte {
|
||||
s := []byte(salt)
|
||||
p := append(pass, s...)
|
||||
h := sha256.Sum256(p)
|
||||
return h[:]
|
||||
}
|
||||
}
|
||||
|
||||
// CryptEncoderSHA384 encodes a password with SHA256
|
||||
func CryptEncoderSHA384(pass []byte, hash interface{}) []byte {
|
||||
h := sha512.Sum384(pass)
|
||||
return h[:]
|
||||
}
|
||||
|
||||
// CryptEncoderSSHA384 encodes a password with SHA256
|
||||
// with the configured salt
|
||||
func CryptEncoderSSHA384(salt string) func(pass []byte, hash interface{}) []byte {
|
||||
return func(pass []byte, hash interface{}) []byte {
|
||||
s := []byte(salt)
|
||||
p := append(pass, s...)
|
||||
h := sha512.Sum384(p)
|
||||
return h[:]
|
||||
}
|
||||
}
|
||||
|
||||
// CryptEncoderSHA512 encodes a password with SHA256
|
||||
func CryptEncoderSHA512(pass []byte, hash interface{}) []byte {
|
||||
h := sha512.Sum512(pass)
|
||||
return h[:]
|
||||
}
|
||||
|
||||
// CryptEncoderSSHA512 encodes a password with SHA256
|
||||
// with the configured salt
|
||||
func CryptEncoderSSHA512(salt string) func(pass []byte, hash interface{}) []byte {
|
||||
return func(pass []byte, hash interface{}) []byte {
|
||||
s := []byte(salt)
|
||||
p := append(pass, s...)
|
||||
h := sha512.Sum512(p)
|
||||
return h[:]
|
||||
}
|
||||
}
|
||||
|
||||
// EOF
|
70
vendor/github.com/mattn/go-sqlite3/sqlite3_go18.go
generated
vendored
70
vendor/github.com/mattn/go-sqlite3/sqlite3_go18.go
generated
vendored
@ -1,70 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build cgo
|
||||
// +build go1.8
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
|
||||
"context"
|
||||
)
|
||||
|
||||
// Ping implement Pinger.
|
||||
func (c *SQLiteConn) Ping(ctx context.Context) error {
|
||||
if c.db == nil {
|
||||
return errors.New("Connection was closed")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueryContext implement QueryerContext.
|
||||
func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
list[i] = namedValue(nv)
|
||||
}
|
||||
return c.query(ctx, query, list)
|
||||
}
|
||||
|
||||
// ExecContext implement ExecerContext.
|
||||
func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
list[i] = namedValue(nv)
|
||||
}
|
||||
return c.exec(ctx, query, list)
|
||||
}
|
||||
|
||||
// PrepareContext implement ConnPrepareContext.
|
||||
func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
return c.prepare(ctx, query)
|
||||
}
|
||||
|
||||
// BeginTx implement ConnBeginTx.
|
||||
func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
return c.begin(ctx)
|
||||
}
|
||||
|
||||
// QueryContext implement QueryerContext.
|
||||
func (s *SQLiteStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
list[i] = namedValue(nv)
|
||||
}
|
||||
return s.query(ctx, list)
|
||||
}
|
||||
|
||||
// ExecContext implement ExecerContext.
|
||||
func (s *SQLiteStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
list := make([]namedValue, len(args))
|
||||
for i, nv := range args {
|
||||
list[i] = namedValue(nv)
|
||||
}
|
||||
return s.exec(ctx, list)
|
||||
}
|
17
vendor/github.com/mattn/go-sqlite3/sqlite3_libsqlite3.go
generated
vendored
17
vendor/github.com/mattn/go-sqlite3/sqlite3_libsqlite3.go
generated
vendored
@ -1,17 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build libsqlite3
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DUSE_LIBSQLITE3
|
||||
#cgo linux LDFLAGS: -lsqlite3
|
||||
#cgo darwin LDFLAGS: -L/usr/local/opt/sqlite/lib -lsqlite3
|
||||
#cgo openbsd LDFLAGS: -lsqlite3
|
||||
#cgo solaris LDFLAGS: -lsqlite3
|
||||
*/
|
||||
import "C"
|
70
vendor/github.com/mattn/go-sqlite3/sqlite3_load_extension.go
generated
vendored
70
vendor/github.com/mattn/go-sqlite3/sqlite3_load_extension.go
generated
vendored
@ -1,70 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !sqlite_omit_load_extension
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func (c *SQLiteConn) loadExtensions(extensions []string) error {
|
||||
rv := C.sqlite3_enable_load_extension(c.db, 1)
|
||||
if rv != C.SQLITE_OK {
|
||||
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
|
||||
}
|
||||
|
||||
for _, extension := range extensions {
|
||||
cext := C.CString(extension)
|
||||
defer C.free(unsafe.Pointer(cext))
|
||||
rv = C.sqlite3_load_extension(c.db, cext, nil, nil)
|
||||
if rv != C.SQLITE_OK {
|
||||
C.sqlite3_enable_load_extension(c.db, 0)
|
||||
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
|
||||
}
|
||||
}
|
||||
|
||||
rv = C.sqlite3_enable_load_extension(c.db, 0)
|
||||
if rv != C.SQLITE_OK {
|
||||
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadExtension load the sqlite3 extension.
|
||||
func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
|
||||
rv := C.sqlite3_enable_load_extension(c.db, 1)
|
||||
if rv != C.SQLITE_OK {
|
||||
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
|
||||
}
|
||||
|
||||
clib := C.CString(lib)
|
||||
defer C.free(unsafe.Pointer(clib))
|
||||
centry := C.CString(entry)
|
||||
defer C.free(unsafe.Pointer(centry))
|
||||
|
||||
rv = C.sqlite3_load_extension(c.db, clib, centry, nil)
|
||||
if rv != C.SQLITE_OK {
|
||||
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
|
||||
}
|
||||
|
||||
rv = C.sqlite3_enable_load_extension(c.db, 0)
|
||||
if rv != C.SQLITE_OK {
|
||||
return errors.New(C.GoString(C.sqlite3_errmsg(c.db)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
24
vendor/github.com/mattn/go-sqlite3/sqlite3_load_extension_omit.go
generated
vendored
24
vendor/github.com/mattn/go-sqlite3/sqlite3_load_extension_omit.go
generated
vendored
@ -1,24 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_omit_load_extension
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_OMIT_LOAD_EXTENSION
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
func (c *SQLiteConn) loadExtensions(extensions []string) error {
|
||||
return errors.New("Extensions have been disabled for static builds")
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) LoadExtension(lib string, entry string) error {
|
||||
return errors.New("Extensions have been disabled for static builds")
|
||||
}
|
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_allow_uri_authority.go
generated
vendored
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_allow_uri_authority.go
generated
vendored
@ -1,15 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
// Copyright (C) 2018 G.J.R. Timmer <gjr.timmer@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_allow_uri_authority
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_ALLOW_URI_AUTHORITY
|
||||
#cgo LDFLAGS: -lm
|
||||
*/
|
||||
import "C"
|
16
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_app_armor.go
generated
vendored
16
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_app_armor.go
generated
vendored
@ -1,16 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
// Copyright (C) 2018 G.J.R. Timmer <gjr.timmer@gmail.com>.
|
||||
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !windows
|
||||
// +build sqlite_app_armor
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_API_ARMOR
|
||||
#cgo LDFLAGS: -lm
|
||||
*/
|
||||
import "C"
|
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_foreign_keys.go
generated
vendored
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_foreign_keys.go
generated
vendored
@ -1,15 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
// Copyright (C) 2018 G.J.R. Timmer <gjr.timmer@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_foreign_keys
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_DEFAULT_FOREIGN_KEYS=1
|
||||
#cgo LDFLAGS: -lm
|
||||
*/
|
||||
import "C"
|
14
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_fts5.go
generated
vendored
14
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_fts5.go
generated
vendored
@ -1,14 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_fts5 fts5
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_FTS5
|
||||
#cgo LDFLAGS: -lm
|
||||
*/
|
||||
import "C"
|
17
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_icu.go
generated
vendored
17
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_icu.go
generated
vendored
@ -1,17 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_icu icu
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo LDFLAGS: -licuuc -licui18n
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_ICU
|
||||
#cgo darwin CFLAGS: -I/usr/local/opt/icu4c/include
|
||||
#cgo darwin LDFLAGS: -L/usr/local/opt/icu4c/lib
|
||||
#cgo openbsd LDFLAGS: -lsqlite3
|
||||
*/
|
||||
import "C"
|
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_introspect.go
generated
vendored
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_introspect.go
generated
vendored
@ -1,15 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
// Copyright (C) 2018 G.J.R. Timmer <gjr.timmer@gmail.com>.
|
||||
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_introspect
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_INTROSPECTION_PRAGMAS
|
||||
#cgo LDFLAGS: -lm
|
||||
*/
|
||||
import "C"
|
13
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_json1.go
generated
vendored
13
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_json1.go
generated
vendored
@ -1,13 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_json sqlite_json1 json1
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_JSON1
|
||||
*/
|
||||
import "C"
|
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_secure_delete.go
generated
vendored
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_secure_delete.go
generated
vendored
@ -1,15 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
// Copyright (C) 2018 G.J.R. Timmer <gjr.timmer@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_secure_delete
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_SECURE_DELETE=1
|
||||
#cgo LDFLAGS: -lm
|
||||
*/
|
||||
import "C"
|
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_secure_delete_fast.go
generated
vendored
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_secure_delete_fast.go
generated
vendored
@ -1,15 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
// Copyright (C) 2018 G.J.R. Timmer <gjr.timmer@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_secure_delete_fast
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_SECURE_DELETE=FAST
|
||||
#cgo LDFLAGS: -lm
|
||||
*/
|
||||
import "C"
|
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_stat4.go
generated
vendored
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_stat4.go
generated
vendored
@ -1,15 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
// Copyright (C) 2018 G.J.R. Timmer <gjr.timmer@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_stat4
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_STAT4
|
||||
#cgo LDFLAGS: -lm
|
||||
*/
|
||||
import "C"
|
289
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_userauth.go
generated
vendored
289
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_userauth.go
generated
vendored
@ -1,289 +0,0 @@
|
||||
// Copyright (C) 2018 G.J.R. Timmer <gjr.timmer@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_userauth
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_USER_AUTHENTICATION
|
||||
#cgo LDFLAGS: -lm
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
|
||||
static int
|
||||
_sqlite3_user_authenticate(sqlite3* db, const char* zUsername, const char* aPW, int nPW)
|
||||
{
|
||||
return sqlite3_user_authenticate(db, zUsername, aPW, nPW);
|
||||
}
|
||||
|
||||
static int
|
||||
_sqlite3_user_add(sqlite3* db, const char* zUsername, const char* aPW, int nPW, int isAdmin)
|
||||
{
|
||||
return sqlite3_user_add(db, zUsername, aPW, nPW, isAdmin);
|
||||
}
|
||||
|
||||
static int
|
||||
_sqlite3_user_change(sqlite3* db, const char* zUsername, const char* aPW, int nPW, int isAdmin)
|
||||
{
|
||||
return sqlite3_user_change(db, zUsername, aPW, nPW, isAdmin);
|
||||
}
|
||||
|
||||
static int
|
||||
_sqlite3_user_delete(sqlite3* db, const char* zUsername)
|
||||
{
|
||||
return sqlite3_user_delete(db, zUsername);
|
||||
}
|
||||
|
||||
static int
|
||||
_sqlite3_auth_enabled(sqlite3* db)
|
||||
{
|
||||
int exists = -1;
|
||||
|
||||
sqlite3_stmt *stmt;
|
||||
sqlite3_prepare_v2(db, "select count(type) from sqlite_master WHERE type='table' and name='sqlite_user';", -1, &stmt, NULL);
|
||||
|
||||
while ( sqlite3_step(stmt) == SQLITE_ROW) {
|
||||
exists = sqlite3_column_int(stmt, 0);
|
||||
}
|
||||
|
||||
sqlite3_finalize(stmt);
|
||||
|
||||
return exists;
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"errors"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
SQLITE_AUTH = C.SQLITE_AUTH
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnauthorized = errors.New("SQLITE_AUTH: Unauthorized")
|
||||
ErrAdminRequired = errors.New("SQLITE_AUTH: Unauthorized; Admin Privileges Required")
|
||||
)
|
||||
|
||||
// Authenticate will perform an authentication of the provided username
|
||||
// and password against the database.
|
||||
//
|
||||
// If a database contains the SQLITE_USER table, then the
|
||||
// call to Authenticate must be invoked with an
|
||||
// appropriate username and password prior to enable read and write
|
||||
//access to the database.
|
||||
//
|
||||
// Return SQLITE_OK on success or SQLITE_ERROR if the username/password
|
||||
// combination is incorrect or unknown.
|
||||
//
|
||||
// If the SQLITE_USER table is not present in the database file, then
|
||||
// this interface is a harmless no-op returnning SQLITE_OK.
|
||||
func (c *SQLiteConn) Authenticate(username, password string) error {
|
||||
rv := c.authenticate(username, password)
|
||||
switch rv {
|
||||
case C.SQLITE_ERROR, C.SQLITE_AUTH:
|
||||
return ErrUnauthorized
|
||||
case C.SQLITE_OK:
|
||||
return nil
|
||||
default:
|
||||
return c.lastError()
|
||||
}
|
||||
}
|
||||
|
||||
// authenticate provides the actual authentication to SQLite.
|
||||
// This is not exported for usage in Go.
|
||||
// It is however exported for usage within SQL by the user.
|
||||
//
|
||||
// Returns:
|
||||
// C.SQLITE_OK (0)
|
||||
// C.SQLITE_ERROR (1)
|
||||
// C.SQLITE_AUTH (23)
|
||||
func (c *SQLiteConn) authenticate(username, password string) int {
|
||||
// Allocate C Variables
|
||||
cuser := C.CString(username)
|
||||
cpass := C.CString(password)
|
||||
|
||||
// Free C Variables
|
||||
defer func() {
|
||||
C.free(unsafe.Pointer(cuser))
|
||||
C.free(unsafe.Pointer(cpass))
|
||||
}()
|
||||
|
||||
return int(C._sqlite3_user_authenticate(c.db, cuser, cpass, C.int(len(password))))
|
||||
}
|
||||
|
||||
// AuthUserAdd can be used (by an admin user only)
|
||||
// to create a new user. When called on a no-authentication-required
|
||||
// database, this routine converts the database into an authentication-
|
||||
// required database, automatically makes the added user an
|
||||
// administrator, and logs in the current connection as that user.
|
||||
// The AuthUserAdd only works for the "main" database, not
|
||||
// for any ATTACH-ed databases. Any call to AuthUserAdd by a
|
||||
// non-admin user results in an error.
|
||||
func (c *SQLiteConn) AuthUserAdd(username, password string, admin bool) error {
|
||||
isAdmin := 0
|
||||
if admin {
|
||||
isAdmin = 1
|
||||
}
|
||||
|
||||
rv := c.authUserAdd(username, password, isAdmin)
|
||||
switch rv {
|
||||
case C.SQLITE_ERROR, C.SQLITE_AUTH:
|
||||
return ErrAdminRequired
|
||||
case C.SQLITE_OK:
|
||||
return nil
|
||||
default:
|
||||
return c.lastError()
|
||||
}
|
||||
}
|
||||
|
||||
// authUserAdd enables the User Authentication if not enabled.
|
||||
// Otherwise it will add a user.
|
||||
//
|
||||
// When user authentication is already enabled then this function
|
||||
// can only be called by an admin.
|
||||
//
|
||||
// This is not exported for usage in Go.
|
||||
// It is however exported for usage within SQL by the user.
|
||||
//
|
||||
// Returns:
|
||||
// C.SQLITE_OK (0)
|
||||
// C.SQLITE_ERROR (1)
|
||||
// C.SQLITE_AUTH (23)
|
||||
func (c *SQLiteConn) authUserAdd(username, password string, admin int) int {
|
||||
// Allocate C Variables
|
||||
cuser := C.CString(username)
|
||||
cpass := C.CString(password)
|
||||
|
||||
// Free C Variables
|
||||
defer func() {
|
||||
C.free(unsafe.Pointer(cuser))
|
||||
C.free(unsafe.Pointer(cpass))
|
||||
}()
|
||||
|
||||
return int(C._sqlite3_user_add(c.db, cuser, cpass, C.int(len(password)), C.int(admin)))
|
||||
}
|
||||
|
||||
// AuthUserChange can be used to change a users
|
||||
// login credentials or admin privilege. Any user can change their own
|
||||
// login credentials. Only an admin user can change another users login
|
||||
// credentials or admin privilege setting. No user may change their own
|
||||
// admin privilege setting.
|
||||
func (c *SQLiteConn) AuthUserChange(username, password string, admin bool) error {
|
||||
isAdmin := 0
|
||||
if admin {
|
||||
isAdmin = 1
|
||||
}
|
||||
|
||||
rv := c.authUserChange(username, password, isAdmin)
|
||||
switch rv {
|
||||
case C.SQLITE_ERROR, C.SQLITE_AUTH:
|
||||
return ErrAdminRequired
|
||||
case C.SQLITE_OK:
|
||||
return nil
|
||||
default:
|
||||
return c.lastError()
|
||||
}
|
||||
}
|
||||
|
||||
// authUserChange allows to modify a user.
|
||||
// Users can change their own password.
|
||||
//
|
||||
// Only admins can change passwords for other users
|
||||
// and modify the admin flag.
|
||||
//
|
||||
// The admin flag of the current logged in user cannot be changed.
|
||||
// THis ensures that their is always an admin.
|
||||
//
|
||||
// This is not exported for usage in Go.
|
||||
// It is however exported for usage within SQL by the user.
|
||||
//
|
||||
// Returns:
|
||||
// C.SQLITE_OK (0)
|
||||
// C.SQLITE_ERROR (1)
|
||||
// C.SQLITE_AUTH (23)
|
||||
func (c *SQLiteConn) authUserChange(username, password string, admin int) int {
|
||||
// Allocate C Variables
|
||||
cuser := C.CString(username)
|
||||
cpass := C.CString(password)
|
||||
|
||||
// Free C Variables
|
||||
defer func() {
|
||||
C.free(unsafe.Pointer(cuser))
|
||||
C.free(unsafe.Pointer(cpass))
|
||||
}()
|
||||
|
||||
return int(C._sqlite3_user_change(c.db, cuser, cpass, C.int(len(password)), C.int(admin)))
|
||||
}
|
||||
|
||||
// AuthUserDelete can be used (by an admin user only)
|
||||
// to delete a user. The currently logged-in user cannot be deleted,
|
||||
// which guarantees that there is always an admin user and hence that
|
||||
// the database cannot be converted into a no-authentication-required
|
||||
// database.
|
||||
func (c *SQLiteConn) AuthUserDelete(username string) error {
|
||||
rv := c.authUserDelete(username)
|
||||
switch rv {
|
||||
case C.SQLITE_ERROR, C.SQLITE_AUTH:
|
||||
return ErrAdminRequired
|
||||
case C.SQLITE_OK:
|
||||
return nil
|
||||
default:
|
||||
return c.lastError()
|
||||
}
|
||||
}
|
||||
|
||||
// authUserDelete can be used to delete a user.
|
||||
//
|
||||
// This function can only be executed by an admin.
|
||||
//
|
||||
// This is not exported for usage in Go.
|
||||
// It is however exported for usage within SQL by the user.
|
||||
//
|
||||
// Returns:
|
||||
// C.SQLITE_OK (0)
|
||||
// C.SQLITE_ERROR (1)
|
||||
// C.SQLITE_AUTH (23)
|
||||
func (c *SQLiteConn) authUserDelete(username string) int {
|
||||
// Allocate C Variables
|
||||
cuser := C.CString(username)
|
||||
|
||||
// Free C Variables
|
||||
defer func() {
|
||||
C.free(unsafe.Pointer(cuser))
|
||||
}()
|
||||
|
||||
return int(C._sqlite3_user_delete(c.db, cuser))
|
||||
}
|
||||
|
||||
// AuthEnabled checks if the database is protected by user authentication
|
||||
func (c *SQLiteConn) AuthEnabled() (exists bool) {
|
||||
rv := c.authEnabled()
|
||||
if rv == 1 {
|
||||
exists = true
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// authEnabled perform the actual check for user authentication.
|
||||
//
|
||||
// This is not exported for usage in Go.
|
||||
// It is however exported for usage within SQL by the user.
|
||||
//
|
||||
// Returns:
|
||||
// 0 - Disabled
|
||||
// 1 - Enabled
|
||||
func (c *SQLiteConn) authEnabled() int {
|
||||
return int(C._sqlite3_auth_enabled(c.db))
|
||||
}
|
||||
|
||||
// EOF
|
152
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_userauth_omit.go
generated
vendored
152
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_userauth_omit.go
generated
vendored
@ -1,152 +0,0 @@
|
||||
// Copyright (C) 2018 G.J.R. Timmer <gjr.timmer@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !sqlite_userauth
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"C"
|
||||
)
|
||||
|
||||
// Authenticate will perform an authentication of the provided username
|
||||
// and password against the database.
|
||||
//
|
||||
// If a database contains the SQLITE_USER table, then the
|
||||
// call to Authenticate must be invoked with an
|
||||
// appropriate username and password prior to enable read and write
|
||||
//access to the database.
|
||||
//
|
||||
// Return SQLITE_OK on success or SQLITE_ERROR if the username/password
|
||||
// combination is incorrect or unknown.
|
||||
//
|
||||
// If the SQLITE_USER table is not present in the database file, then
|
||||
// this interface is a harmless no-op returnning SQLITE_OK.
|
||||
func (c *SQLiteConn) Authenticate(username, password string) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// authenticate provides the actual authentication to SQLite.
|
||||
// This is not exported for usage in Go.
|
||||
// It is however exported for usage within SQL by the user.
|
||||
//
|
||||
// Returns:
|
||||
// C.SQLITE_OK (0)
|
||||
// C.SQLITE_ERROR (1)
|
||||
// C.SQLITE_AUTH (23)
|
||||
func (c *SQLiteConn) authenticate(username, password string) int {
|
||||
// NOOP
|
||||
return 0
|
||||
}
|
||||
|
||||
// AuthUserAdd can be used (by an admin user only)
|
||||
// to create a new user. When called on a no-authentication-required
|
||||
// database, this routine converts the database into an authentication-
|
||||
// required database, automatically makes the added user an
|
||||
// administrator, and logs in the current connection as that user.
|
||||
// The AuthUserAdd only works for the "main" database, not
|
||||
// for any ATTACH-ed databases. Any call to AuthUserAdd by a
|
||||
// non-admin user results in an error.
|
||||
func (c *SQLiteConn) AuthUserAdd(username, password string, admin bool) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// authUserAdd enables the User Authentication if not enabled.
|
||||
// Otherwise it will add a user.
|
||||
//
|
||||
// When user authentication is already enabled then this function
|
||||
// can only be called by an admin.
|
||||
//
|
||||
// This is not exported for usage in Go.
|
||||
// It is however exported for usage within SQL by the user.
|
||||
//
|
||||
// Returns:
|
||||
// C.SQLITE_OK (0)
|
||||
// C.SQLITE_ERROR (1)
|
||||
// C.SQLITE_AUTH (23)
|
||||
func (c *SQLiteConn) authUserAdd(username, password string, admin int) int {
|
||||
// NOOP
|
||||
return 0
|
||||
}
|
||||
|
||||
// AuthUserChange can be used to change a users
|
||||
// login credentials or admin privilege. Any user can change their own
|
||||
// login credentials. Only an admin user can change another users login
|
||||
// credentials or admin privilege setting. No user may change their own
|
||||
// admin privilege setting.
|
||||
func (c *SQLiteConn) AuthUserChange(username, password string, admin bool) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// authUserChange allows to modify a user.
|
||||
// Users can change their own password.
|
||||
//
|
||||
// Only admins can change passwords for other users
|
||||
// and modify the admin flag.
|
||||
//
|
||||
// The admin flag of the current logged in user cannot be changed.
|
||||
// THis ensures that their is always an admin.
|
||||
//
|
||||
// This is not exported for usage in Go.
|
||||
// It is however exported for usage within SQL by the user.
|
||||
//
|
||||
// Returns:
|
||||
// C.SQLITE_OK (0)
|
||||
// C.SQLITE_ERROR (1)
|
||||
// C.SQLITE_AUTH (23)
|
||||
func (c *SQLiteConn) authUserChange(username, password string, admin int) int {
|
||||
// NOOP
|
||||
return 0
|
||||
}
|
||||
|
||||
// AuthUserDelete can be used (by an admin user only)
|
||||
// to delete a user. The currently logged-in user cannot be deleted,
|
||||
// which guarantees that there is always an admin user and hence that
|
||||
// the database cannot be converted into a no-authentication-required
|
||||
// database.
|
||||
func (c *SQLiteConn) AuthUserDelete(username string) error {
|
||||
// NOOP
|
||||
return nil
|
||||
}
|
||||
|
||||
// authUserDelete can be used to delete a user.
|
||||
//
|
||||
// This function can only be executed by an admin.
|
||||
//
|
||||
// This is not exported for usage in Go.
|
||||
// It is however exported for usage within SQL by the user.
|
||||
//
|
||||
// Returns:
|
||||
// C.SQLITE_OK (0)
|
||||
// C.SQLITE_ERROR (1)
|
||||
// C.SQLITE_AUTH (23)
|
||||
func (c *SQLiteConn) authUserDelete(username string) int {
|
||||
// NOOP
|
||||
return 0
|
||||
}
|
||||
|
||||
// AuthEnabled checks if the database is protected by user authentication
|
||||
func (c *SQLiteConn) AuthEnabled() (exists bool) {
|
||||
// NOOP
|
||||
return false
|
||||
}
|
||||
|
||||
// authEnabled perform the actual check for user authentication.
|
||||
//
|
||||
// This is not exported for usage in Go.
|
||||
// It is however exported for usage within SQL by the user.
|
||||
//
|
||||
// Returns:
|
||||
// 0 - Disabled
|
||||
// 1 - Enabled
|
||||
func (c *SQLiteConn) authEnabled() int {
|
||||
// NOOP
|
||||
return 0
|
||||
}
|
||||
|
||||
// EOF
|
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_vacuum_full.go
generated
vendored
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_vacuum_full.go
generated
vendored
@ -1,15 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
// Copyright (C) 2018 G.J.R. Timmer <gjr.timmer@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_vacuum_full
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_DEFAULT_AUTOVACUUM=1
|
||||
#cgo LDFLAGS: -lm
|
||||
*/
|
||||
import "C"
|
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_vacuum_incr.go
generated
vendored
15
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_vacuum_incr.go
generated
vendored
@ -1,15 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
// Copyright (C) 2018 G.J.R. Timmer <gjr.timmer@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_vacuum_incr
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -DSQLITE_DEFAULT_AUTOVACUUM=2
|
||||
#cgo LDFLAGS: -lm
|
||||
*/
|
||||
import "C"
|
650
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_vtable.go
generated
vendored
650
vendor/github.com/mattn/go-sqlite3/sqlite3_opt_vtable.go
generated
vendored
@ -1,650 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_vtable vtable
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -std=gnu99
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE
|
||||
#cgo CFLAGS: -DSQLITE_THREADSAFE
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3_PARENTHESIS
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_FTS4_UNICODE61
|
||||
#cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15
|
||||
#cgo CFLAGS: -DSQLITE_ENABLE_COLUMN_METADATA=1
|
||||
#cgo CFLAGS: -Wno-deprecated-declarations
|
||||
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
#include <memory.h>
|
||||
|
||||
static inline char *_sqlite3_mprintf(char *zFormat, char *arg) {
|
||||
return sqlite3_mprintf(zFormat, arg);
|
||||
}
|
||||
|
||||
typedef struct goVTab goVTab;
|
||||
|
||||
struct goVTab {
|
||||
sqlite3_vtab base;
|
||||
void *vTab;
|
||||
};
|
||||
|
||||
uintptr_t goMInit(void *db, void *pAux, int argc, char **argv, char **pzErr, int isCreate);
|
||||
|
||||
static int cXInit(sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr, int isCreate) {
|
||||
void *vTab = (void *)goMInit(db, pAux, argc, (char**)argv, pzErr, isCreate);
|
||||
if (!vTab || *pzErr) {
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
goVTab *pvTab = (goVTab *)sqlite3_malloc(sizeof(goVTab));
|
||||
if (!pvTab) {
|
||||
*pzErr = sqlite3_mprintf("%s", "Out of memory");
|
||||
return SQLITE_NOMEM;
|
||||
}
|
||||
memset(pvTab, 0, sizeof(goVTab));
|
||||
pvTab->vTab = vTab;
|
||||
|
||||
*ppVTab = (sqlite3_vtab *)pvTab;
|
||||
*pzErr = 0;
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
static inline int cXCreate(sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr) {
|
||||
return cXInit(db, pAux, argc, argv, ppVTab, pzErr, 1);
|
||||
}
|
||||
static inline int cXConnect(sqlite3 *db, void *pAux, int argc, const char *const*argv, sqlite3_vtab **ppVTab, char **pzErr) {
|
||||
return cXInit(db, pAux, argc, argv, ppVTab, pzErr, 0);
|
||||
}
|
||||
|
||||
char* goVBestIndex(void *pVTab, void *icp);
|
||||
|
||||
static inline int cXBestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *info) {
|
||||
char *pzErr = goVBestIndex(((goVTab*)pVTab)->vTab, info);
|
||||
if (pzErr) {
|
||||
if (pVTab->zErrMsg)
|
||||
sqlite3_free(pVTab->zErrMsg);
|
||||
pVTab->zErrMsg = pzErr;
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
char* goVRelease(void *pVTab, int isDestroy);
|
||||
|
||||
static int cXRelease(sqlite3_vtab *pVTab, int isDestroy) {
|
||||
char *pzErr = goVRelease(((goVTab*)pVTab)->vTab, isDestroy);
|
||||
if (pzErr) {
|
||||
if (pVTab->zErrMsg)
|
||||
sqlite3_free(pVTab->zErrMsg);
|
||||
pVTab->zErrMsg = pzErr;
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
if (pVTab->zErrMsg)
|
||||
sqlite3_free(pVTab->zErrMsg);
|
||||
sqlite3_free(pVTab);
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
static inline int cXDisconnect(sqlite3_vtab *pVTab) {
|
||||
return cXRelease(pVTab, 0);
|
||||
}
|
||||
static inline int cXDestroy(sqlite3_vtab *pVTab) {
|
||||
return cXRelease(pVTab, 1);
|
||||
}
|
||||
|
||||
typedef struct goVTabCursor goVTabCursor;
|
||||
|
||||
struct goVTabCursor {
|
||||
sqlite3_vtab_cursor base;
|
||||
void *vTabCursor;
|
||||
};
|
||||
|
||||
uintptr_t goVOpen(void *pVTab, char **pzErr);
|
||||
|
||||
static int cXOpen(sqlite3_vtab *pVTab, sqlite3_vtab_cursor **ppCursor) {
|
||||
void *vTabCursor = (void *)goVOpen(((goVTab*)pVTab)->vTab, &(pVTab->zErrMsg));
|
||||
goVTabCursor *pCursor = (goVTabCursor *)sqlite3_malloc(sizeof(goVTabCursor));
|
||||
if (!pCursor) {
|
||||
return SQLITE_NOMEM;
|
||||
}
|
||||
memset(pCursor, 0, sizeof(goVTabCursor));
|
||||
pCursor->vTabCursor = vTabCursor;
|
||||
*ppCursor = (sqlite3_vtab_cursor *)pCursor;
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
static int setErrMsg(sqlite3_vtab_cursor *pCursor, char *pzErr) {
|
||||
if (pCursor->pVtab->zErrMsg)
|
||||
sqlite3_free(pCursor->pVtab->zErrMsg);
|
||||
pCursor->pVtab->zErrMsg = pzErr;
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
|
||||
char* goVClose(void *pCursor);
|
||||
|
||||
static int cXClose(sqlite3_vtab_cursor *pCursor) {
|
||||
char *pzErr = goVClose(((goVTabCursor*)pCursor)->vTabCursor);
|
||||
if (pzErr) {
|
||||
return setErrMsg(pCursor, pzErr);
|
||||
}
|
||||
sqlite3_free(pCursor);
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
char* goVFilter(void *pCursor, int idxNum, char* idxName, int argc, sqlite3_value **argv);
|
||||
|
||||
static int cXFilter(sqlite3_vtab_cursor *pCursor, int idxNum, const char *idxStr, int argc, sqlite3_value **argv) {
|
||||
char *pzErr = goVFilter(((goVTabCursor*)pCursor)->vTabCursor, idxNum, (char*)idxStr, argc, argv);
|
||||
if (pzErr) {
|
||||
return setErrMsg(pCursor, pzErr);
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
char* goVNext(void *pCursor);
|
||||
|
||||
static int cXNext(sqlite3_vtab_cursor *pCursor) {
|
||||
char *pzErr = goVNext(((goVTabCursor*)pCursor)->vTabCursor);
|
||||
if (pzErr) {
|
||||
return setErrMsg(pCursor, pzErr);
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
int goVEof(void *pCursor);
|
||||
|
||||
static inline int cXEof(sqlite3_vtab_cursor *pCursor) {
|
||||
return goVEof(((goVTabCursor*)pCursor)->vTabCursor);
|
||||
}
|
||||
|
||||
char* goVColumn(void *pCursor, void *cp, int col);
|
||||
|
||||
static int cXColumn(sqlite3_vtab_cursor *pCursor, sqlite3_context *ctx, int i) {
|
||||
char *pzErr = goVColumn(((goVTabCursor*)pCursor)->vTabCursor, ctx, i);
|
||||
if (pzErr) {
|
||||
return setErrMsg(pCursor, pzErr);
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
char* goVRowid(void *pCursor, sqlite3_int64 *pRowid);
|
||||
|
||||
static int cXRowid(sqlite3_vtab_cursor *pCursor, sqlite3_int64 *pRowid) {
|
||||
char *pzErr = goVRowid(((goVTabCursor*)pCursor)->vTabCursor, pRowid);
|
||||
if (pzErr) {
|
||||
return setErrMsg(pCursor, pzErr);
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
char* goVUpdate(void *pVTab, int argc, sqlite3_value **argv, sqlite3_int64 *pRowid);
|
||||
|
||||
static int cXUpdate(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv, sqlite3_int64 *pRowid) {
|
||||
char *pzErr = goVUpdate(((goVTab*)pVTab)->vTab, argc, argv, pRowid);
|
||||
if (pzErr) {
|
||||
if (pVTab->zErrMsg)
|
||||
sqlite3_free(pVTab->zErrMsg);
|
||||
pVTab->zErrMsg = pzErr;
|
||||
return SQLITE_ERROR;
|
||||
}
|
||||
return SQLITE_OK;
|
||||
}
|
||||
|
||||
static sqlite3_module goModule = {
|
||||
0, // iVersion
|
||||
cXCreate, // xCreate - create a table
|
||||
cXConnect, // xConnect - connect to an existing table
|
||||
cXBestIndex, // xBestIndex - Determine search strategy
|
||||
cXDisconnect, // xDisconnect - Disconnect from a table
|
||||
cXDestroy, // xDestroy - Drop a table
|
||||
cXOpen, // xOpen - open a cursor
|
||||
cXClose, // xClose - close a cursor
|
||||
cXFilter, // xFilter - configure scan constraints
|
||||
cXNext, // xNext - advance a cursor
|
||||
cXEof, // xEof
|
||||
cXColumn, // xColumn - read data
|
||||
cXRowid, // xRowid - read data
|
||||
cXUpdate, // xUpdate - write data
|
||||
// Not implemented
|
||||
0, // xBegin - begin transaction
|
||||
0, // xSync - sync transaction
|
||||
0, // xCommit - commit transaction
|
||||
0, // xRollback - rollback transaction
|
||||
0, // xFindFunction - function overloading
|
||||
0, // xRename - rename the table
|
||||
0, // xSavepoint
|
||||
0, // xRelease
|
||||
0 // xRollbackTo
|
||||
};
|
||||
|
||||
void goMDestroy(void*);
|
||||
|
||||
static int _sqlite3_create_module(sqlite3 *db, const char *zName, uintptr_t pClientData) {
|
||||
return sqlite3_create_module_v2(db, zName, &goModule, (void*) pClientData, goMDestroy);
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type sqliteModule struct {
|
||||
c *SQLiteConn
|
||||
name string
|
||||
module Module
|
||||
}
|
||||
|
||||
type sqliteVTab struct {
|
||||
module *sqliteModule
|
||||
vTab VTab
|
||||
}
|
||||
|
||||
type sqliteVTabCursor struct {
|
||||
vTab *sqliteVTab
|
||||
vTabCursor VTabCursor
|
||||
}
|
||||
|
||||
// Op is type of operations.
|
||||
type Op uint8
|
||||
|
||||
// Op mean identity of operations.
|
||||
const (
|
||||
OpEQ Op = 2
|
||||
OpGT = 4
|
||||
OpLE = 8
|
||||
OpLT = 16
|
||||
OpGE = 32
|
||||
OpMATCH = 64
|
||||
OpLIKE = 65 /* 3.10.0 and later only */
|
||||
OpGLOB = 66 /* 3.10.0 and later only */
|
||||
OpREGEXP = 67 /* 3.10.0 and later only */
|
||||
OpScanUnique = 1 /* Scan visits at most 1 row */
|
||||
)
|
||||
|
||||
// InfoConstraint give information of constraint.
|
||||
type InfoConstraint struct {
|
||||
Column int
|
||||
Op Op
|
||||
Usable bool
|
||||
}
|
||||
|
||||
// InfoOrderBy give information of order-by.
|
||||
type InfoOrderBy struct {
|
||||
Column int
|
||||
Desc bool
|
||||
}
|
||||
|
||||
func constraints(info *C.sqlite3_index_info) []InfoConstraint {
|
||||
l := info.nConstraint
|
||||
slice := (*[1 << 30]C.struct_sqlite3_index_constraint)(unsafe.Pointer(info.aConstraint))[:l:l]
|
||||
|
||||
cst := make([]InfoConstraint, 0, l)
|
||||
for _, c := range slice {
|
||||
var usable bool
|
||||
if c.usable > 0 {
|
||||
usable = true
|
||||
}
|
||||
cst = append(cst, InfoConstraint{
|
||||
Column: int(c.iColumn),
|
||||
Op: Op(c.op),
|
||||
Usable: usable,
|
||||
})
|
||||
}
|
||||
return cst
|
||||
}
|
||||
|
||||
func orderBys(info *C.sqlite3_index_info) []InfoOrderBy {
|
||||
l := info.nOrderBy
|
||||
slice := (*[1 << 30]C.struct_sqlite3_index_orderby)(unsafe.Pointer(info.aOrderBy))[:l:l]
|
||||
|
||||
ob := make([]InfoOrderBy, 0, l)
|
||||
for _, c := range slice {
|
||||
var desc bool
|
||||
if c.desc > 0 {
|
||||
desc = true
|
||||
}
|
||||
ob = append(ob, InfoOrderBy{
|
||||
Column: int(c.iColumn),
|
||||
Desc: desc,
|
||||
})
|
||||
}
|
||||
return ob
|
||||
}
|
||||
|
||||
// IndexResult is a Go struct representation of what eventually ends up in the
|
||||
// output fields for `sqlite3_index_info`
|
||||
// See: https://www.sqlite.org/c3ref/index_info.html
|
||||
type IndexResult struct {
|
||||
Used []bool // aConstraintUsage
|
||||
IdxNum int
|
||||
IdxStr string
|
||||
AlreadyOrdered bool // orderByConsumed
|
||||
EstimatedCost float64
|
||||
EstimatedRows float64
|
||||
}
|
||||
|
||||
// mPrintf is a utility wrapper around sqlite3_mprintf
|
||||
func mPrintf(format, arg string) *C.char {
|
||||
cf := C.CString(format)
|
||||
defer C.free(unsafe.Pointer(cf))
|
||||
ca := C.CString(arg)
|
||||
defer C.free(unsafe.Pointer(ca))
|
||||
return C._sqlite3_mprintf(cf, ca)
|
||||
}
|
||||
|
||||
//export goMInit
|
||||
func goMInit(db, pClientData unsafe.Pointer, argc C.int, argv **C.char, pzErr **C.char, isCreate C.int) C.uintptr_t {
|
||||
m := lookupHandle(uintptr(pClientData)).(*sqliteModule)
|
||||
if m.c.db != (*C.sqlite3)(db) {
|
||||
*pzErr = mPrintf("%s", "Inconsistent db handles")
|
||||
return 0
|
||||
}
|
||||
args := make([]string, argc)
|
||||
var A []*C.char
|
||||
slice := reflect.SliceHeader{Data: uintptr(unsafe.Pointer(argv)), Len: int(argc), Cap: int(argc)}
|
||||
a := reflect.NewAt(reflect.TypeOf(A), unsafe.Pointer(&slice)).Elem().Interface()
|
||||
for i, s := range a.([]*C.char) {
|
||||
args[i] = C.GoString(s)
|
||||
}
|
||||
var vTab VTab
|
||||
var err error
|
||||
if isCreate == 1 {
|
||||
vTab, err = m.module.Create(m.c, args)
|
||||
} else {
|
||||
vTab, err = m.module.Connect(m.c, args)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
*pzErr = mPrintf("%s", err.Error())
|
||||
return 0
|
||||
}
|
||||
vt := sqliteVTab{m, vTab}
|
||||
*pzErr = nil
|
||||
return C.uintptr_t(newHandle(m.c, &vt))
|
||||
}
|
||||
|
||||
//export goVRelease
|
||||
func goVRelease(pVTab unsafe.Pointer, isDestroy C.int) *C.char {
|
||||
vt := lookupHandle(uintptr(pVTab)).(*sqliteVTab)
|
||||
var err error
|
||||
if isDestroy == 1 {
|
||||
err = vt.vTab.Destroy()
|
||||
} else {
|
||||
err = vt.vTab.Disconnect()
|
||||
}
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goVOpen
|
||||
func goVOpen(pVTab unsafe.Pointer, pzErr **C.char) C.uintptr_t {
|
||||
vt := lookupHandle(uintptr(pVTab)).(*sqliteVTab)
|
||||
vTabCursor, err := vt.vTab.Open()
|
||||
if err != nil {
|
||||
*pzErr = mPrintf("%s", err.Error())
|
||||
return 0
|
||||
}
|
||||
vtc := sqliteVTabCursor{vt, vTabCursor}
|
||||
*pzErr = nil
|
||||
return C.uintptr_t(newHandle(vt.module.c, &vtc))
|
||||
}
|
||||
|
||||
//export goVBestIndex
|
||||
func goVBestIndex(pVTab unsafe.Pointer, icp unsafe.Pointer) *C.char {
|
||||
vt := lookupHandle(uintptr(pVTab)).(*sqliteVTab)
|
||||
info := (*C.sqlite3_index_info)(icp)
|
||||
csts := constraints(info)
|
||||
res, err := vt.vTab.BestIndex(csts, orderBys(info))
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
if len(res.Used) != len(csts) {
|
||||
return mPrintf("Result.Used != expected value", "")
|
||||
}
|
||||
|
||||
// Get a pointer to constraint_usage struct so we can update in place.
|
||||
l := info.nConstraint
|
||||
s := (*[1 << 30]C.struct_sqlite3_index_constraint_usage)(unsafe.Pointer(info.aConstraintUsage))[:l:l]
|
||||
index := 1
|
||||
for i := C.int(0); i < info.nConstraint; i++ {
|
||||
if res.Used[i] {
|
||||
s[i].argvIndex = C.int(index)
|
||||
s[i].omit = C.uchar(1)
|
||||
index++
|
||||
}
|
||||
}
|
||||
|
||||
info.idxNum = C.int(res.IdxNum)
|
||||
idxStr := C.CString(res.IdxStr)
|
||||
defer C.free(unsafe.Pointer(idxStr))
|
||||
info.idxStr = idxStr
|
||||
info.needToFreeIdxStr = C.int(0)
|
||||
if res.AlreadyOrdered {
|
||||
info.orderByConsumed = C.int(1)
|
||||
}
|
||||
info.estimatedCost = C.double(res.EstimatedCost)
|
||||
info.estimatedRows = C.sqlite3_int64(res.EstimatedRows)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goVClose
|
||||
func goVClose(pCursor unsafe.Pointer) *C.char {
|
||||
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
|
||||
err := vtc.vTabCursor.Close()
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goMDestroy
|
||||
func goMDestroy(pClientData unsafe.Pointer) {
|
||||
m := lookupHandle(uintptr(pClientData)).(*sqliteModule)
|
||||
m.module.DestroyModule()
|
||||
}
|
||||
|
||||
//export goVFilter
|
||||
func goVFilter(pCursor unsafe.Pointer, idxNum C.int, idxName *C.char, argc C.int, argv **C.sqlite3_value) *C.char {
|
||||
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
|
||||
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
|
||||
vals := make([]interface{}, 0, argc)
|
||||
for _, v := range args {
|
||||
conv, err := callbackArgGeneric(v)
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
vals = append(vals, conv.Interface())
|
||||
}
|
||||
err := vtc.vTabCursor.Filter(int(idxNum), C.GoString(idxName), vals)
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goVNext
|
||||
func goVNext(pCursor unsafe.Pointer) *C.char {
|
||||
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
|
||||
err := vtc.vTabCursor.Next()
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goVEof
|
||||
func goVEof(pCursor unsafe.Pointer) C.int {
|
||||
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
|
||||
err := vtc.vTabCursor.EOF()
|
||||
if err {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
//export goVColumn
|
||||
func goVColumn(pCursor, cp unsafe.Pointer, col C.int) *C.char {
|
||||
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
|
||||
c := (*SQLiteContext)(cp)
|
||||
err := vtc.vTabCursor.Column(c, int(col))
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goVRowid
|
||||
func goVRowid(pCursor unsafe.Pointer, pRowid *C.sqlite3_int64) *C.char {
|
||||
vtc := lookupHandle(uintptr(pCursor)).(*sqliteVTabCursor)
|
||||
rowid, err := vtc.vTabCursor.Rowid()
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
*pRowid = C.sqlite3_int64(rowid)
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goVUpdate
|
||||
func goVUpdate(pVTab unsafe.Pointer, argc C.int, argv **C.sqlite3_value, pRowid *C.sqlite3_int64) *C.char {
|
||||
vt := lookupHandle(uintptr(pVTab)).(*sqliteVTab)
|
||||
|
||||
var tname string
|
||||
if n, ok := vt.vTab.(interface {
|
||||
TableName() string
|
||||
}); ok {
|
||||
tname = n.TableName() + " "
|
||||
}
|
||||
|
||||
err := fmt.Errorf("virtual %s table %sis read-only", vt.module.name, tname)
|
||||
if v, ok := vt.vTab.(VTabUpdater); ok {
|
||||
// convert argv
|
||||
args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
|
||||
vals := make([]interface{}, 0, argc)
|
||||
for _, v := range args {
|
||||
conv, err := callbackArgGeneric(v)
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
|
||||
// work around for SQLITE_NULL
|
||||
x := conv.Interface()
|
||||
if z, ok := x.([]byte); ok && z == nil {
|
||||
x = nil
|
||||
}
|
||||
|
||||
vals = append(vals, x)
|
||||
}
|
||||
|
||||
switch {
|
||||
case argc == 1:
|
||||
err = v.Delete(vals[0])
|
||||
|
||||
case argc > 1 && vals[0] == nil:
|
||||
var id int64
|
||||
id, err = v.Insert(vals[1], vals[2:])
|
||||
if err == nil {
|
||||
*pRowid = C.sqlite3_int64(id)
|
||||
}
|
||||
|
||||
case argc > 1:
|
||||
err = v.Update(vals[1], vals[2:])
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return mPrintf("%s", err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Module is a "virtual table module", it defines the implementation of a
|
||||
// virtual tables. See: http://sqlite.org/c3ref/module.html
|
||||
type Module interface {
|
||||
// http://sqlite.org/vtab.html#xcreate
|
||||
Create(c *SQLiteConn, args []string) (VTab, error)
|
||||
// http://sqlite.org/vtab.html#xconnect
|
||||
Connect(c *SQLiteConn, args []string) (VTab, error)
|
||||
// http://sqlite.org/c3ref/create_module.html
|
||||
DestroyModule()
|
||||
}
|
||||
|
||||
// VTab describes a particular instance of the virtual table.
|
||||
// See: http://sqlite.org/c3ref/vtab.html
|
||||
type VTab interface {
|
||||
// http://sqlite.org/vtab.html#xbestindex
|
||||
BestIndex([]InfoConstraint, []InfoOrderBy) (*IndexResult, error)
|
||||
// http://sqlite.org/vtab.html#xdisconnect
|
||||
Disconnect() error
|
||||
// http://sqlite.org/vtab.html#sqlite3_module.xDestroy
|
||||
Destroy() error
|
||||
// http://sqlite.org/vtab.html#xopen
|
||||
Open() (VTabCursor, error)
|
||||
}
|
||||
|
||||
// VTabUpdater is a type that allows a VTab to be inserted, updated, or
|
||||
// deleted.
|
||||
// See: https://sqlite.org/vtab.html#xupdate
|
||||
type VTabUpdater interface {
|
||||
Delete(interface{}) error
|
||||
Insert(interface{}, []interface{}) (int64, error)
|
||||
Update(interface{}, []interface{}) error
|
||||
}
|
||||
|
||||
// VTabCursor describes cursors that point into the virtual table and are used
|
||||
// to loop through the virtual table. See: http://sqlite.org/c3ref/vtab_cursor.html
|
||||
type VTabCursor interface {
|
||||
// http://sqlite.org/vtab.html#xclose
|
||||
Close() error
|
||||
// http://sqlite.org/vtab.html#xfilter
|
||||
Filter(idxNum int, idxStr string, vals []interface{}) error
|
||||
// http://sqlite.org/vtab.html#xnext
|
||||
Next() error
|
||||
// http://sqlite.org/vtab.html#xeof
|
||||
EOF() bool
|
||||
// http://sqlite.org/vtab.html#xcolumn
|
||||
Column(c *SQLiteContext, col int) error
|
||||
// http://sqlite.org/vtab.html#xrowid
|
||||
Rowid() (int64, error)
|
||||
}
|
||||
|
||||
// DeclareVTab declares the Schema of a virtual table.
|
||||
// See: http://sqlite.org/c3ref/declare_vtab.html
|
||||
func (c *SQLiteConn) DeclareVTab(sql string) error {
|
||||
zSQL := C.CString(sql)
|
||||
defer C.free(unsafe.Pointer(zSQL))
|
||||
rv := C.sqlite3_declare_vtab(c.db, zSQL)
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateModule registers a virtual table implementation.
|
||||
// See: http://sqlite.org/c3ref/create_module.html
|
||||
func (c *SQLiteConn) CreateModule(moduleName string, module Module) error {
|
||||
mname := C.CString(moduleName)
|
||||
defer C.free(unsafe.Pointer(mname))
|
||||
udm := sqliteModule{c, moduleName, module}
|
||||
rv := C._sqlite3_create_module(c.db, mname, C.uintptr_t(newHandle(c, &udm)))
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
14
vendor/github.com/mattn/go-sqlite3/sqlite3_other.go
generated
vendored
14
vendor/github.com/mattn/go-sqlite3/sqlite3_other.go
generated
vendored
@ -1,14 +0,0 @@
|
||||
// Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !windows
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I.
|
||||
#cgo linux LDFLAGS: -ldl
|
||||
*/
|
||||
import "C"
|
14
vendor/github.com/mattn/go-sqlite3/sqlite3_solaris.go
generated
vendored
14
vendor/github.com/mattn/go-sqlite3/sqlite3_solaris.go
generated
vendored
@ -1,14 +0,0 @@
|
||||
// Copyright (C) 2018 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build solaris
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -D__EXTENSIONS__=1
|
||||
#cgo LDFLAGS: -lc
|
||||
*/
|
||||
import "C"
|
288
vendor/github.com/mattn/go-sqlite3/sqlite3_trace.go
generated
vendored
288
vendor/github.com/mattn/go-sqlite3/sqlite3_trace.go
generated
vendored
@ -1,288 +0,0 @@
|
||||
// Copyright (C) 2016 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
|
||||
//
|
||||
// Use of this source code is governed by an MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build sqlite_trace trace
|
||||
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
#include <stdlib.h>
|
||||
|
||||
int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Trace... constants identify the possible events causing callback invocation.
|
||||
// Values are same as the corresponding SQLite Trace Event Codes.
|
||||
const (
|
||||
TraceStmt = uint32(C.SQLITE_TRACE_STMT)
|
||||
TraceProfile = uint32(C.SQLITE_TRACE_PROFILE)
|
||||
TraceRow = uint32(C.SQLITE_TRACE_ROW)
|
||||
TraceClose = uint32(C.SQLITE_TRACE_CLOSE)
|
||||
)
|
||||
|
||||
type TraceInfo struct {
|
||||
// Pack together the shorter fields, to keep the struct smaller.
|
||||
// On a 64-bit machine there would be padding
|
||||
// between EventCode and ConnHandle; having AutoCommit here is "free":
|
||||
EventCode uint32
|
||||
AutoCommit bool
|
||||
ConnHandle uintptr
|
||||
|
||||
// Usually filled, unless EventCode = TraceClose = SQLITE_TRACE_CLOSE:
|
||||
// identifier for a prepared statement:
|
||||
StmtHandle uintptr
|
||||
|
||||
// Two strings filled when EventCode = TraceStmt = SQLITE_TRACE_STMT:
|
||||
// (1) either the unexpanded SQL text of the prepared statement, or
|
||||
// an SQL comment that indicates the invocation of a trigger;
|
||||
// (2) expanded SQL, if requested and if (1) is not an SQL comment.
|
||||
StmtOrTrigger string
|
||||
ExpandedSQL string // only if requested (TraceConfig.WantExpandedSQL = true)
|
||||
|
||||
// filled when EventCode = TraceProfile = SQLITE_TRACE_PROFILE:
|
||||
// estimated number of nanoseconds that the prepared statement took to run:
|
||||
RunTimeNanosec int64
|
||||
|
||||
DBError Error
|
||||
}
|
||||
|
||||
// TraceUserCallback gives the signature for a trace function
|
||||
// provided by the user (Go application programmer).
|
||||
// SQLite 3.14 documentation (as of September 2, 2016)
|
||||
// for SQL Trace Hook = sqlite3_trace_v2():
|
||||
// The integer return value from the callback is currently ignored,
|
||||
// though this may change in future releases. Callback implementations
|
||||
// should return zero to ensure future compatibility.
|
||||
type TraceUserCallback func(TraceInfo) int
|
||||
|
||||
type TraceConfig struct {
|
||||
Callback TraceUserCallback
|
||||
EventMask uint32
|
||||
WantExpandedSQL bool
|
||||
}
|
||||
|
||||
func fillDBError(dbErr *Error, db *C.sqlite3) {
|
||||
// See SQLiteConn.lastError(), in file 'sqlite3.go' at the time of writing (Sept 5, 2016)
|
||||
dbErr.Code = ErrNo(C.sqlite3_errcode(db))
|
||||
dbErr.ExtendedCode = ErrNoExtended(C.sqlite3_extended_errcode(db))
|
||||
dbErr.err = C.GoString(C.sqlite3_errmsg(db))
|
||||
}
|
||||
|
||||
func fillExpandedSQL(info *TraceInfo, db *C.sqlite3, pStmt unsafe.Pointer) {
|
||||
if pStmt == nil {
|
||||
panic("No SQLite statement pointer in P arg of trace_v2 callback")
|
||||
}
|
||||
|
||||
expSQLiteCStr := C.sqlite3_expanded_sql((*C.sqlite3_stmt)(pStmt))
|
||||
if expSQLiteCStr == nil {
|
||||
fillDBError(&info.DBError, db)
|
||||
return
|
||||
}
|
||||
info.ExpandedSQL = C.GoString(expSQLiteCStr)
|
||||
}
|
||||
|
||||
//export traceCallbackTrampoline
|
||||
func traceCallbackTrampoline(
|
||||
traceEventCode C.uint,
|
||||
// Parameter named 'C' in SQLite docs = Context given at registration:
|
||||
ctx unsafe.Pointer,
|
||||
// Parameter named 'P' in SQLite docs (Primary event data?):
|
||||
p unsafe.Pointer,
|
||||
// Parameter named 'X' in SQLite docs (eXtra event data?):
|
||||
xValue unsafe.Pointer) C.int {
|
||||
|
||||
eventCode := uint32(traceEventCode)
|
||||
|
||||
if ctx == nil {
|
||||
panic(fmt.Sprintf("No context (ev 0x%x)", traceEventCode))
|
||||
}
|
||||
|
||||
contextDB := (*C.sqlite3)(ctx)
|
||||
connHandle := uintptr(ctx)
|
||||
|
||||
var traceConf TraceConfig
|
||||
var found bool
|
||||
if eventCode == TraceClose {
|
||||
// clean up traceMap: 'pop' means get and delete
|
||||
traceConf, found = popTraceMapping(connHandle)
|
||||
} else {
|
||||
traceConf, found = lookupTraceMapping(connHandle)
|
||||
}
|
||||
|
||||
if !found {
|
||||
panic(fmt.Sprintf("Mapping not found for handle 0x%x (ev 0x%x)",
|
||||
connHandle, eventCode))
|
||||
}
|
||||
|
||||
var info TraceInfo
|
||||
|
||||
info.EventCode = eventCode
|
||||
info.AutoCommit = (int(C.sqlite3_get_autocommit(contextDB)) != 0)
|
||||
info.ConnHandle = connHandle
|
||||
|
||||
switch eventCode {
|
||||
case TraceStmt:
|
||||
info.StmtHandle = uintptr(p)
|
||||
|
||||
var xStr string
|
||||
if xValue != nil {
|
||||
xStr = C.GoString((*C.char)(xValue))
|
||||
}
|
||||
info.StmtOrTrigger = xStr
|
||||
if !strings.HasPrefix(xStr, "--") {
|
||||
// Not SQL comment, therefore the current event
|
||||
// is not related to a trigger.
|
||||
// The user might want to receive the expanded SQL;
|
||||
// let's check:
|
||||
if traceConf.WantExpandedSQL {
|
||||
fillExpandedSQL(&info, contextDB, p)
|
||||
}
|
||||
}
|
||||
|
||||
case TraceProfile:
|
||||
info.StmtHandle = uintptr(p)
|
||||
|
||||
if xValue == nil {
|
||||
panic("NULL pointer in X arg of trace_v2 callback for SQLITE_TRACE_PROFILE event")
|
||||
}
|
||||
|
||||
info.RunTimeNanosec = *(*int64)(xValue)
|
||||
|
||||
// sample the error //TODO: is it safe? is it useful?
|
||||
fillDBError(&info.DBError, contextDB)
|
||||
|
||||
case TraceRow:
|
||||
info.StmtHandle = uintptr(p)
|
||||
|
||||
case TraceClose:
|
||||
handle := uintptr(p)
|
||||
if handle != info.ConnHandle {
|
||||
panic(fmt.Sprintf("Different conn handle 0x%x (expected 0x%x) in SQLITE_TRACE_CLOSE event.",
|
||||
handle, info.ConnHandle))
|
||||
}
|
||||
|
||||
default:
|
||||
// Pass unsupported events to the user callback (if configured);
|
||||
// let the user callback decide whether to panic or ignore them.
|
||||
}
|
||||
|
||||
// Do not execute user callback when the event was not requested by user!
|
||||
// Remember that the Close event is always selected when
|
||||
// registering this callback trampoline with SQLite --- for cleanup.
|
||||
// In the future there may be more events forced to "selected" in SQLite
|
||||
// for the driver's needs.
|
||||
if traceConf.EventMask&eventCode == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
r := 0
|
||||
if traceConf.Callback != nil {
|
||||
r = traceConf.Callback(info)
|
||||
}
|
||||
return C.int(r)
|
||||
}
|
||||
|
||||
type traceMapEntry struct {
|
||||
config TraceConfig
|
||||
}
|
||||
|
||||
var traceMapLock sync.Mutex
|
||||
var traceMap = make(map[uintptr]traceMapEntry)
|
||||
|
||||
func addTraceMapping(connHandle uintptr, traceConf TraceConfig) {
|
||||
traceMapLock.Lock()
|
||||
defer traceMapLock.Unlock()
|
||||
|
||||
oldEntryCopy, found := traceMap[connHandle]
|
||||
if found {
|
||||
panic(fmt.Sprintf("Adding trace config %v: handle 0x%x already registered (%v).",
|
||||
traceConf, connHandle, oldEntryCopy.config))
|
||||
}
|
||||
traceMap[connHandle] = traceMapEntry{config: traceConf}
|
||||
fmt.Printf("Added trace config %v: handle 0x%x.\n", traceConf, connHandle)
|
||||
}
|
||||
|
||||
func lookupTraceMapping(connHandle uintptr) (TraceConfig, bool) {
|
||||
traceMapLock.Lock()
|
||||
defer traceMapLock.Unlock()
|
||||
|
||||
entryCopy, found := traceMap[connHandle]
|
||||
return entryCopy.config, found
|
||||
}
|
||||
|
||||
// 'pop' = get and delete from map before returning the value to the caller
|
||||
func popTraceMapping(connHandle uintptr) (TraceConfig, bool) {
|
||||
traceMapLock.Lock()
|
||||
defer traceMapLock.Unlock()
|
||||
|
||||
entryCopy, found := traceMap[connHandle]
|
||||
if found {
|
||||
delete(traceMap, connHandle)
|
||||
fmt.Printf("Pop handle 0x%x: deleted trace config %v.\n", connHandle, entryCopy.config)
|
||||
}
|
||||
return entryCopy.config, found
|
||||
}
|
||||
|
||||
// SetTrace installs or removes the trace callback for the given database connection.
|
||||
// It's not named 'RegisterTrace' because only one callback can be kept and called.
|
||||
// Calling SetTrace a second time on same database connection
|
||||
// overrides (cancels) any prior callback and all its settings:
|
||||
// event mask, etc.
|
||||
func (c *SQLiteConn) SetTrace(requested *TraceConfig) error {
|
||||
connHandle := uintptr(unsafe.Pointer(c.db))
|
||||
|
||||
_, _ = popTraceMapping(connHandle)
|
||||
|
||||
if requested == nil {
|
||||
// The traceMap entry was deleted already by popTraceMapping():
|
||||
// can disable all events now, no need to watch for TraceClose.
|
||||
err := c.setSQLiteTrace(0)
|
||||
return err
|
||||
}
|
||||
|
||||
reqCopy := *requested
|
||||
|
||||
// Disable potentially expensive operations
|
||||
// if their result will not be used. We are doing this
|
||||
// just in case the caller provided nonsensical input.
|
||||
if reqCopy.EventMask&TraceStmt == 0 {
|
||||
reqCopy.WantExpandedSQL = false
|
||||
}
|
||||
|
||||
addTraceMapping(connHandle, reqCopy)
|
||||
|
||||
// The callback trampoline function does cleanup on Close event,
|
||||
// regardless of the presence or absence of the user callback.
|
||||
// Therefore it needs the Close event to be selected:
|
||||
actualEventMask := uint(reqCopy.EventMask | TraceClose)
|
||||
err := c.setSQLiteTrace(actualEventMask)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SQLiteConn) setSQLiteTrace(sqliteEventMask uint) error {
|
||||
rv := C.sqlite3_trace_v2(c.db,
|
||||
C.uint(sqliteEventMask),
|
||||
(*[0]byte)(unsafe.Pointer(C.traceCallbackTrampoline)),
|
||||
unsafe.Pointer(c.db)) // Fourth arg is same as first: we are
|
||||
// passing the database connection handle as callback context.
|
||||
|
||||
if rv != C.SQLITE_OK {
|
||||
return c.lastError()
|
||||
}
|
||||
return nil
|
||||
}
|
57
vendor/github.com/mattn/go-sqlite3/sqlite3_type.go
generated
vendored
57
vendor/github.com/mattn/go-sqlite3/sqlite3_type.go
generated
vendored
@ -1,57 +0,0 @@
|
||||
package sqlite3
|
||||
|
||||
/*
|
||||
#ifndef USE_LIBSQLITE3
|
||||
#include <sqlite3-binding.h>
|
||||
#else
|
||||
#include <sqlite3.h>
|
||||
#endif
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ColumnTypeDatabaseTypeName implement RowsColumnTypeDatabaseTypeName.
|
||||
func (rc *SQLiteRows) ColumnTypeDatabaseTypeName(i int) string {
|
||||
return C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))
|
||||
}
|
||||
|
||||
/*
|
||||
func (rc *SQLiteRows) ColumnTypeLength(index int) (length int64, ok bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (rc *SQLiteRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
|
||||
return 0, 0, false
|
||||
}
|
||||
*/
|
||||
|
||||
// ColumnTypeNullable implement RowsColumnTypeNullable.
|
||||
func (rc *SQLiteRows) ColumnTypeNullable(i int) (nullable, ok bool) {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// ColumnTypeScanType implement RowsColumnTypeScanType.
|
||||
func (rc *SQLiteRows) ColumnTypeScanType(i int) reflect.Type {
|
||||
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
|
||||
case C.SQLITE_INTEGER:
|
||||
switch C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))) {
|
||||
case "timestamp", "datetime", "date":
|
||||
return reflect.TypeOf(time.Time{})
|
||||
case "boolean":
|
||||
return reflect.TypeOf(false)
|
||||
}
|
||||
return reflect.TypeOf(int64(0))
|
||||
case C.SQLITE_FLOAT:
|
||||
return reflect.TypeOf(float64(0))
|
||||
case C.SQLITE_BLOB:
|
||||
return reflect.SliceOf(reflect.TypeOf(byte(0)))
|
||||
case C.SQLITE_NULL:
|
||||
return reflect.TypeOf(nil)
|
||||
case C.SQLITE_TEXT:
|
||||
return reflect.TypeOf("")
|
||||
}
|
||||
return reflect.SliceOf(reflect.TypeOf(byte(0)))
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user