Add authorization to digest-score
This commit is contained in:
@ -16,13 +16,16 @@ from urllib.parse import urlparse
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import dateutil.parser
|
import dateutil.parser
|
||||||
from google.cloud import storage
|
from google.cloud import storage
|
||||||
from features.user_history import FEATURE_COLUMNS
|
|
||||||
|
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
from threading import Lock, RLock
|
from threading import Lock, RLock
|
||||||
from collections import ChainMap
|
from collections import ChainMap
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
from features.user_history import FEATURE_COLUMNS
|
||||||
|
from auth import user_token_required, admin_token_required
|
||||||
|
|
||||||
|
|
||||||
class ThreadSafeUserFeatures:
|
class ThreadSafeUserFeatures:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._data = {}
|
self._data = {}
|
||||||
@ -219,12 +222,14 @@ def metrics():
|
|||||||
|
|
||||||
|
|
||||||
@app.route('/refresh', methods=['GET'])
|
@app.route('/refresh', methods=['GET'])
|
||||||
|
@admin_token_required
|
||||||
def refresh():
|
def refresh():
|
||||||
refresh_data()
|
refresh_data()
|
||||||
return jsonify({'OK': 'yes'}), 200
|
return jsonify({'OK': 'yes'}), 200
|
||||||
|
|
||||||
|
|
||||||
@app.route('/users/<user_id>/features', methods=['GET'])
|
@app.route('/users/<user_id>/features', methods=['GET'])
|
||||||
|
@admin_token_required
|
||||||
def get_user_features(user_id):
|
def get_user_features(user_id):
|
||||||
result = {}
|
result = {}
|
||||||
df_user = pd.DataFrame([{
|
df_user = pd.DataFrame([{
|
||||||
@ -243,12 +248,13 @@ def get_user_features(user_id):
|
|||||||
|
|
||||||
|
|
||||||
@app.route('/predict', methods=['POST'])
|
@app.route('/predict', methods=['POST'])
|
||||||
|
@user_token_required
|
||||||
def predict():
|
def predict():
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
app.logger.info(f"predict scoring request: {data}")
|
app.logger.info(f"predict scoring request: {data}")
|
||||||
|
|
||||||
user_id = data.get('user_id')
|
user_id = request.user_id
|
||||||
item_features = data.get('item_features')
|
item_features = data.get('item_features')
|
||||||
|
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
@ -263,14 +269,13 @@ def predict():
|
|||||||
|
|
||||||
|
|
||||||
@app.route('/batch', methods=['POST'])
|
@app.route('/batch', methods=['POST'])
|
||||||
|
@user_token_required
|
||||||
def batch():
|
def batch():
|
||||||
start = timer()
|
start = timer()
|
||||||
try:
|
try:
|
||||||
data = request.get_json()
|
data = request.get_json()
|
||||||
app.logger.info(f"batch scoring request: {data}")
|
|
||||||
|
|
||||||
items = data.get('items')
|
items = data.get('items')
|
||||||
user_id = data.get('user_id')
|
user_id = request.user_id
|
||||||
if user_id == None:
|
if user_id == None:
|
||||||
return jsonify({'error': 'no user_id supplied'}), 400
|
return jsonify({'error': 'no user_id supplied'}), 400
|
||||||
if len(items) > 101:
|
if len(items) > 101:
|
||||||
|
|||||||
45
ml/digest-score/auth.py
Normal file
45
ml/digest-score/auth.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import os
|
||||||
|
import jwt
|
||||||
|
from flask import request, jsonify
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
SECRET_KEY = os.getenv('JWT_SECRET')
|
||||||
|
ADMIN_SECRET_KEY = os.getenv('JWT_ADMIN_SECRET_KEY')
|
||||||
|
|
||||||
|
def user_token_required(f):
|
||||||
|
@wraps(f)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
token = None
|
||||||
|
if 'Authorization' in request.headers:
|
||||||
|
print("request.headers['Authorization'].split(" ")[1]", request.headers['Authorization'].split(" ")[1])
|
||||||
|
token = request.headers['Authorization'].split(" ")[1]
|
||||||
|
if not token:
|
||||||
|
return jsonify({'message': 'Token is missing!'}), 401
|
||||||
|
try:
|
||||||
|
data = jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
|
||||||
|
request.user_id = data['uid']
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
return jsonify({'message': 'Token has expired!'}), 401
|
||||||
|
except jwt.InvalidTokenError:
|
||||||
|
return jsonify({'message': 'Token is invalid!'}), 401
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
def admin_token_required(f):
|
||||||
|
@wraps(f)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
token = None
|
||||||
|
if 'Authorization' in request.headers:
|
||||||
|
token = request.headers['Authorization'].split(" ")[1]
|
||||||
|
if not token:
|
||||||
|
return jsonify({'message': 'Token is missing!'}), 401
|
||||||
|
try:
|
||||||
|
data = jwt.decode(token, ADMIN_SECRET_KEY, algorithms=["HS256"])
|
||||||
|
if data['role'] != 'admin':
|
||||||
|
return jsonify({'message': 'Admin token required!'}), 403
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
return jsonify({'message': 'Token has expired!'}), 401
|
||||||
|
except jwt.InvalidTokenError:
|
||||||
|
return jsonify({'message': 'Token is invalid!'}), 401
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
return decorated
|
||||||
@ -9,5 +9,6 @@ sklearn2pmml
|
|||||||
sqlalchemy
|
sqlalchemy
|
||||||
pyarrow
|
pyarrow
|
||||||
requests
|
requests
|
||||||
|
PyJWT
|
||||||
prometheus_client
|
prometheus_client
|
||||||
xgboost==2.1.0
|
xgboost==2.1.0
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import client from 'prom-client'
|
|||||||
import { env } from '../env'
|
import { env } from '../env'
|
||||||
import { registerMetric } from '../prometheus'
|
import { registerMetric } from '../prometheus'
|
||||||
import { logError } from '../utils/logger'
|
import { logError } from '../utils/logger'
|
||||||
|
import { createWebAuthToken } from '../routers/auth/jwt_helpers'
|
||||||
|
|
||||||
export interface Feature {
|
export interface Feature {
|
||||||
library_item_id?: string
|
library_item_id?: string
|
||||||
@ -77,10 +78,12 @@ class ScoreClientImpl implements ScoreClient {
|
|||||||
|
|
||||||
async getScores(data: ScoreApiRequestBody): Promise<ScoreApiResponse> {
|
async getScores(data: ScoreApiRequestBody): Promise<ScoreApiResponse> {
|
||||||
const start = Date.now()
|
const start = Date.now()
|
||||||
|
const authToken = createWebAuthToken(data.user_id)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const response = await axios.post<ScoreApiResponse>(this.apiUrl, data, {
|
const response = await axios.post<ScoreApiResponse>(this.apiUrl, data, {
|
||||||
headers: {
|
headers: {
|
||||||
|
Authorization: `Bearer ${authToken}`,
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
},
|
},
|
||||||
timeout: 5000,
|
timeout: 5000,
|
||||||
|
|||||||
Reference in New Issue
Block a user