Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 120 additions & 63 deletions annexremote/annexremote.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class SpecialRemote(metaclass=ABCMeta):
Note that the user is not required to provided all the settings listed here.
"""

def __init__(self, annex):
def __init__(self, annex=None):
self.annex = annex
self.info = {}
self.configs = {}
Expand Down Expand Up @@ -522,7 +522,14 @@ def renameexport(self, key, filename, new_filename):
"""
raise UnsupportedRequest()

class Protocol(object):
class AsyncRemote:
def __init__(self):
self.async_supported = True

def getRemoteJob(self, job_number):
raise NotImplementedError

class Protocol:
"""
Helper class handling the receiving part of the protocol (git-annex to remote)
It parses the requests coming from git-annex and calls the respective
Expand All @@ -531,11 +538,14 @@ class Protocol(object):
It is not further documented as it was never intended to be part of the public API.
"""

def __init__(self, remote):
def __init__(self, remote, master):
self.remote = remote
self.request_messages = _GitAnnexRequestMessages(self, self.remote)
self.master = master
self.version = "VERSION 1"
self.exporting = False
self.extensions = list()
self.jobs = {}

def command(self, line):
line = line.strip()
Expand All @@ -544,7 +554,7 @@ def command(self, line):
raise ProtocolError("Got empty line")


method = self.lookupMethod(parts[0]) or self.do_UNKNOWN
method = self.lookupMethod(parts[0])


try:
Expand All @@ -555,14 +565,63 @@ def command(self, line):
except TypeError as e:
raise SyntaxError(e)
else:
if method != self.do_EXPORT:
if method != self.request_messages.do_EXPORT:
self.exporting = False
return reply

def lookupMethod(self, command):
return getattr(self, 'do_' + command.upper(), None)
return getattr(self.request_messages, 'do_' + command.upper(), self.request_messages.do_UNKNOWN)


def get_job(self, job_number: int) -> "Protocol":
if job_number not in self.jobs:
self.jobs[job_number] = Protocol(self.remote.getRemoteJob(job_number))
return self.jobs[job_number]

def error(self, *args):
self._send("ERROR", *args)

def check_key(self, key):
def debug(self, *args):
self._send("DEBUG", *args)

def _ask(self, request, reply_keyword, reply_count):
self._send(request)
line = self.master.input.readline().rstrip().split(" ", reply_count)
if line and line[0] == reply_keyword:
line.extend([""] * (reply_count+1-len(line)))
return line[1:]
else:
raise UnexpectedMessage("Expected {reply_keyword} and {reply_count} values. Got {line}".format(reply_keyword=reply_keyword, reply_count=reply_count, line=line))

def _askvalues(self, request):
self._send(request)
reply = []
while True:
# due to a bug in python 2 we can't use an iterator here: https://bugs.python.org/issue1633941
line = self.master.input.readline()
line = line.rstrip()
line = line.split(" ", 1)
if len(line) == 2 and line[0] == "VALUE":
reply.append(line[1])
elif len(line) == 1 and line[0] == "VALUE":
return reply
else:
raise UnexpectedMessage("Expected VALUE {value}")

def _askvalue(self, request):
(reply,) = self._ask(request, "VALUE", 1)
return reply

def _send(self, *args, **kwargs):
print(*args, file=self.master.output, **kwargs)
self.master.output.flush()

class _GitAnnexRequestMessages(object):
def __init__(self, protocol, remote):
self.protocol = protocol
self.remote = remote

def _check_key(self, key):
if len(key.split()) != 1:
raise ValueError("Invalid key. Key contains whitespace character")

Expand All @@ -578,9 +637,20 @@ def do_INITREMOTE(self):
return "INITREMOTE-SUCCESS"

def do_EXTENSIONS(self, param):
self.extensions = param.split(" ")
return "EXTENSIONS"
self.protocol.extensions = param.split(" ")
remote_extensions = []
if hasattr(self.remote, "async_support") and \
self.remote.async_support == True:
remote_extensions.append("ASYNC")
return ' '.join(["EXTENSIONS"] + remote_extensions)

def do_J(self, param):
try:
(job_number, command) = param.split(" ", 1)
except ValueError:
raise SyntaxError("Expected Jobnumber Command")
return self.protocol.get_job(job_number).command(command)

def do_PREPARE(self):
try:
self.remote.prepare()
Expand All @@ -607,7 +677,7 @@ def do_TRANSFER(self, param):
return "TRANSFER-SUCCESS {method} {key}".format(method=method, key=key)

def do_CHECKPRESENT(self, key):
self.check_key(key)
self._check_key(key)
try:
if self.remote.checkpresent(key):
return "CHECKPRESENT-SUCCESS {key}".format(key=key)
Expand All @@ -617,7 +687,7 @@ def do_CHECKPRESENT(self, key):
return "CHECKPRESENT-UNKNOWN {key} {e}".format(key=key, e=e)

def do_REMOVE(self, key):
self.check_key(key)
self._check_key(key)

try:
self.remote.remove(key)
Expand Down Expand Up @@ -695,7 +765,7 @@ def do_CHECKURL(self, url):


def do_WHEREIS(self, key):
self.check_key(key)
self._check_key(key)
reply = self.remote.whereis(key)
if reply:
return "WHEREIS-SUCCESS {reply}".format(reply=reply)
Expand All @@ -721,10 +791,10 @@ def do_EXPORTSUPPORTED(self):
return "EXPORTSUPPORTED-FAILURE"

def do_EXPORT(self, name):
self.exporting = name
self.protocol.exporting = name

def do_TRANSFEREXPORT(self, param):
if not self.exporting:
if not self.protocol.exporting:
raise ProtocolError("Export request without prior EXPORT")
try:
(method, key, file_) = param.split(" ", 2)
Expand All @@ -736,31 +806,31 @@ def do_TRANSFEREXPORT(self, param):

func = getattr(self.remote, "transferexport_{}".format(method.lower()), None)
try:
func(key, file_, self.exporting)
func(key, file_, self.protocol.exporting)
except RemoteError as e:
return "TRANSFER-FAILURE {method} {key} {e}".format(method=method, key=key, e=e)
else:
return "TRANSFER-SUCCESS {method} {key}".format(method=method, key=key)

def do_CHECKPRESENTEXPORT(self, key):
if not self.exporting:
if not self.protocol.exporting:
raise ProtocolError("Export request without prior EXPORT")
self.check_key(key)
self._check_key(key)
try:
if self.remote.checkpresentexport(key, self.exporting):
if self.remote.checkpresentexport(key, self.protocol.exporting):
return "CHECKPRESENT-SUCCESS {key}".format(key=key)
else:
return "CHECKPRESENT-FAILURE {key}".format(key=key)
except RemoteError as e:
return "CHECKPRESENT-UNKNOWN {key} {e}".format(key=key, e=e)

def do_REMOVEEXPORT(self, key):
if not self.exporting:
if not self.protocol.exporting:
raise ProtocolError("Export request without prior EXPORT")
self.check_key(key)
self._check_key(key)

try:
self.remote.removeexport(key, self.exporting)
self.remote.removeexport(key, self.protocol.exporting)
except RemoteError as e:
return "REMOVE-FAILURE {key} {e}".format(key=key, e=e)
else:
Expand All @@ -775,15 +845,15 @@ def do_REMOVEEXPORTDIRECTORY(self, name):
return "REMOVEEXPORTDIRECTORY-SUCCESS"

def do_RENAMEEXPORT(self, param):
if not self.exporting:
if not self.protocol.exporting:
raise ProtocolError("Export request without prior EXPORT")
try:
(key, new_name) = param.split(None, 1)
except ValueError:
raise SyntaxError("Expected TRANSFER STORE Key File")

try:
self.remote.renameexport(key, self.exporting, new_name)
self.remote.renameexport(key, self.protocol.exporting, new_name)
except RemoteError:
return "RENAMEEXPORT-FAILURE {key}".format(key=key)
else:
Expand Down Expand Up @@ -819,6 +889,7 @@ def __init__(self, output=sys.stdout):
Default: sys.stdout
"""
self.output = output
self.input = sys.stdin

def LinkRemote(self, remote):
"""
Expand All @@ -831,7 +902,8 @@ def LinkRemote(self, remote):
ExternalSpecialRemote interface to which this master will be linked.
"""
self.remote = remote
self.protocol = Protocol(remote)
self.protocol = Protocol(remote, self)
self.remote.annex = SpecialRemoteMessages(self.protocol)

def LoggingHandler(self):
"""
Expand All @@ -841,7 +913,7 @@ def LoggingHandler(self):
-------
AnnexLoggingHandler
"""
return AnnexLoggingHandler(self)
return AnnexLoggingHandler(self.protocol)

def Listen(self, input=sys.stdin):
"""
Expand All @@ -862,7 +934,7 @@ def Listen(self, input=sys.stdin):
raise NotLinkedError("Please execute LinkRemote(remote) first.")

self.input = input
self._send(self.protocol.version)
self.protocol._send(self.protocol.version)
while True:
# due to a bug in python 2 we can't use an iterator here: https://bugs.python.org/issue1633941
line = self.input.readline()
Expand All @@ -872,43 +944,32 @@ def Listen(self, input=sys.stdin):
try:
reply = self.protocol.command(line)
if reply:
self._send(reply)
self.protocol._send(reply)
except UnsupportedRequest:
self._send ("UNSUPPORTED-REQUEST")
self.protocol._send ("UNSUPPORTED-REQUEST")
except Exception as e:
for line in traceback.format_exc().splitlines():
self.debug(line)
self.error(e)
self.protocol.debug(line)
self.protocol.error(e)
raise SystemExit

def _ask(self, request, reply_keyword, reply_count):
self._send(request)
line = self.input.readline().rstrip().split(" ", reply_count)
if line and line[0] == reply_keyword:
line.extend([""] * (reply_count+1-len(line)))
return line[1:]
else:
raise UnexpectedMessage("Expected {reply_keyword} and {reply_count} values. Got {line}".format(reply_keyword=reply_keyword, reply_count=reply_count, line=line))

def _askvalues(self, request):
self._send(request)
reply = []
while True:
# due to a bug in python 2 we can't use an iterator here: https://bugs.python.org/issue1633941
line = self.input.readline()
line = line.rstrip()
line = line.split(" ", 1)
if len(line) == 2 and line[0] == "VALUE":
reply.append(line[1])
elif len(line) == 1 and line[0] == "VALUE":
return reply
else:
raise UnexpectedMessage("Expected VALUE {value}")
class SpecialRemoteMessages:
def __init__(self, protocol):
self.protocol = protocol

def _ask(self, *args, **kwargs):
return self.protocol._ask(*args, **kwargs)

def _askvalues(self, *args, **kwargs):
return self.protocol._askvalues(*args, **kwargs)

def _askvalue(self, *args, **kwargs):
return self.protocol._askvalue(*args, **kwargs)

def _send(self, *args, **kwargs):
return self.protocol._send(*args, **kwargs)

def _askvalue(self, request):
(reply,) = self._ask(request, "VALUE", 1)
return reply

def getconfig(self, setting):
"""
Gets one of the special remote's configuration settings,
Expand Down Expand Up @@ -999,7 +1060,7 @@ def debug(self, *args):
The message to be displayed to the user
"""

self._send("DEBUG", *args)
self.protocol.debug(*args)

def error(self, *args):
"""
Expand All @@ -1013,7 +1074,7 @@ def error(self, *args):
error_msg : str
The error message to be sent to git-annex
"""
self._send("ERROR", *args)
self.protocol.error(*args)

def progress(self, progress):
"""
Expand Down Expand Up @@ -1323,8 +1384,4 @@ def getgitremotename(self):
return self._askvalue("GETGITREMOTENAME")
else:
raise ProtocolError("GETGITREMOTENAME not available")


def _send(self, *args, **kwargs):
print(*args, file=self.output, **kwargs)
self.output.flush()

Loading